Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stable diffusion integration #1240

Merged
merged 18 commits into from
Oct 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/_toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ parts:
title: OpenAI
- file: source/reference/ai/yolo
title: YOLO
- file: source/reference/ai/stablediffusion
title: Stable Diffusion

- file: source/reference/ai/custom-ai-function
title: Bring Your Own AI Function
Expand Down
27 changes: 27 additions & 0 deletions docs/source/reference/ai/stablediffusion.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
.. _stablediffusion:

Stable Diffusion Models
======================================

This section provides an overview of how you can generate images from prompts in EvaDB using a Stable Diffusion model.


Introduction
------------

Stable Diffusion models leverage a controlled random walk process to generate intricate patterns and images from textual prompts,
bridging the gap between text and visual representation. EvaDB uses the stable diffusion implementation from `Replicate <https://replicate.com>`_.

Stable Diffusion UDF
--------------------

In order to create an image generation function in EvaDB, use the following SQL command:

.. code-block:: sql

CREATE FUNCTION IF NOT EXISTS StableDiffusion
IMPL 'evadb/functions/stable_diffusion.py';

EvaDB automatically uses the latest `stable diffusion release <https://replicate.com/stability-ai/stable-diffusion/versions>`_ available on Replicate.

To see a demo of how the function can be used, please check the `demo notebook <https://colab.research.google.com/github/georgia-tech-db/eva/blob/master/tutorials/18-stable-diffusion.ipynb>`_ on stable diffusion.
1 change: 1 addition & 0 deletions evadb/evadb.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ third_party:
OPENAI_KEY: ""
PINECONE_API_KEY: ""
PINECONE_ENV: ""
REPLICATE_API_TOKEN: ""
88 changes: 88 additions & 0 deletions evadb/functions/dalle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from io import BytesIO

import numpy as np
import pandas as pd
import requests
from PIL import Image

from evadb.catalog.catalog_type import NdArrayType
from evadb.configuration.configuration_manager import ConfigurationManager
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import forward
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
from evadb.utils.generic_utils import try_to_import_openai


class DallEFunction(AbstractFunction):
@property
def name(self) -> str:
return "DallE"

def setup(self) -> None:
pass

@forward(
input_signatures=[
PandasDataframe(
columns=["prompt"],
column_types=[
NdArrayType.STR,
],
column_shapes=[(None,)],
)
],
output_signatures=[
PandasDataframe(
columns=["response"],
column_types=[NdArrayType.FLOAT32],
column_shapes=[(None, None, 3)],
)
],
)
def forward(self, text_df):
try_to_import_openai()
import openai

# Register API key, try configuration manager first
openai.api_key = ConfigurationManager().get_value("third_party", "OPENAI_KEY")
# If not found, try OS Environment Variable
if len(openai.api_key) == 0:
openai.api_key = os.environ.get("OPENAI_KEY", "")
assert (
len(openai.api_key) != 0
), "Please set your OpenAI API key in evadb.yml file (third_party, open_api_key) or environment variable (OPENAI_KEY)"

def generate_image(text_df: PandasDataframe):
results = []
queries = text_df[text_df.columns[0]]
for query in queries:
response = openai.Image.create(prompt=query, n=1, size="1024x1024")

# Download the image from the link
image_response = requests.get(response["data"][0]["url"])
image = Image.open(BytesIO(image_response.content))

# Convert the image to an array format suitable for the DataFrame
frame = np.array(image)
results.append(frame)

return results

df = pd.DataFrame({"response": generate_image(text_df=text_df)})
return df
14 changes: 14 additions & 0 deletions evadb/functions/function_bootstrap_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,18 @@
MODEL 'yolov8n.pt';
"""

stablediffusion_function_query = """CREATE FUNCTION IF NOT EXISTS StableDiffusion
IMPL '{}/functions/stable_diffusion.py';
""".format(
EvaDB_INSTALLATION_DIR
)

dalle_function_query = """CREATE FUNCTION IF NOT EXISTS DallE
IMPL '{}/functions/dalle.py';
""".format(
EvaDB_INSTALLATION_DIR
)


def init_builtin_functions(db: EvaDBDatabase, mode: str = "debug") -> None:
"""Load the built-in functions into the system during system bootstrapping.
Expand Down Expand Up @@ -247,6 +259,8 @@ def init_builtin_functions(db: EvaDBDatabase, mode: str = "debug") -> None:
# Mvit_function_query,
Sift_function_query,
Yolo_function_query,
stablediffusion_function_query,
dalle_function_query,
]

# if mode is 'debug', add debug functions
Expand Down
102 changes: 102 additions & 0 deletions evadb/functions/stable_diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from io import BytesIO

import numpy as np
import pandas as pd
import requests
from PIL import Image

from evadb.catalog.catalog_type import NdArrayType
from evadb.configuration.configuration_manager import ConfigurationManager
from evadb.functions.abstract.abstract_function import AbstractFunction
from evadb.functions.decorators.decorators import forward
from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe
from evadb.utils.generic_utils import try_to_import_replicate


class StableDiffusion(AbstractFunction):
@property
def name(self) -> str:
return "StableDiffusion"

def setup(
self,
) -> None:
pass

@forward(
input_signatures=[
PandasDataframe(
columns=["prompt"],
column_types=[
NdArrayType.STR,
],
column_shapes=[(None,)],
)
],
output_signatures=[
PandasDataframe(
columns=["response"],
column_types=[
# FileFormatType.IMAGE,
NdArrayType.FLOAT32
],
column_shapes=[(None, None, 3)],
)
],
)
def forward(self, text_df):
try_to_import_replicate()
import replicate

# Register API key, try configuration manager first
replicate_api_key = ConfigurationManager().get_value(
"third_party", "REPLICATE_API_TOKEN"
)
# If not found, try OS Environment Variable
if len(replicate_api_key) == 0:
replicate_api_key = os.environ.get("REPLICATE_API_TOKEN", "")
assert (
len(replicate_api_key) != 0
), "Please set your Replicate API key in evadb.yml file (third_party, replicate_api_token) or environment variable (REPLICATE_API_TOKEN)"
os.environ["REPLICATE_API_TOKEN"] = replicate_api_key

model_id = (
replicate.models.get("stability-ai/stable-diffusion").versions.list()[0].id
)

def generate_image(text_df: PandasDataframe):
results = []
queries = text_df[text_df.columns[0]]
for query in queries:
output = replicate.run(
"stability-ai/stable-diffusion:" + model_id, input={"prompt": query}
)

# Download the image from the link
response = requests.get(output[0])
image = Image.open(BytesIO(response.content))

# Convert the image to an array format suitable for the DataFrame
frame = np.array(image)
results.append(frame)

return results

df = pd.DataFrame({"response": generate_image(text_df=text_df)})
return df
18 changes: 18 additions & 0 deletions evadb/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,3 +629,21 @@ def string_comparison_case_insensitive(string_1, string_2) -> bool:
return False

return string_1.lower() == string_2.lower()


def try_to_import_replicate():
try:
import replicate # noqa: F401
except ImportError:
raise ValueError(
"""Could not import replicate python package.
Please install it with `pip install replicate`."""
)


def is_replicate_available():
try:
try_to_import_replicate()
return True
except ValueError:
return False
2 changes: 1 addition & 1 deletion script/test/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ long_integration_test() {
}

notebook_test() {
PYTHONPATH=./ python -m pytest --durations=5 --nbmake --overwrite "./tutorials" --capture=sys --tb=short -v --log-level=WARNING --nbmake-timeout=3000 --ignore="tutorials/08-chatgpt.ipynb" --ignore="tutorials/14-food-review-tone-analysis-and-response.ipynb" --ignore="tutorials/15-AI-powered-join.ipynb" --ignore="tutorials/16-homesale-forecasting.ipynb" --ignore="tutorials/17-home-rental-prediction.ipynb"
PYTHONPATH=./ python -m pytest --durations=5 --nbmake --overwrite "./tutorials" --capture=sys --tb=short -v --log-level=WARNING --nbmake-timeout=3000 --ignore="tutorials/08-chatgpt.ipynb" --ignore="tutorials/14-food-review-tone-analysis-and-response.ipynb" --ignore="tutorials/15-AI-powered-join.ipynb" --ignore="tutorials/16-homesale-forecasting.ipynb" --ignore="tutorials/17-home-rental-prediction.ipynb" --ignore="tutorials/18-stable-diffusion.ipynb"
code=$?
print_error_code $code "NOTEBOOK TEST"
}
Expand Down
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ def read(path, encoding="utf-8"):
"neuralforecast" # MODEL TRAIN AND FINE TUNING
]

imagegen_libs = [
"replicate"
]

### NEEDED FOR DEVELOPER TESTING ONLY

dev_libs = [
Expand Down Expand Up @@ -167,7 +171,7 @@ def read(path, encoding="utf-8"):
"sklearn": sklearn_libs,
"forecasting": forecasting_libs,
# everything except ray, qdrant, ludwig and postgres. The first three fail on pyhton 3.11.
"dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs,
"dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs,
}

setup(
Expand Down
68 changes: 68 additions & 0 deletions test/integration_tests/long/functions/test_stablediffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# coding=utf-8
# Copyright 2018-2023 EvaDB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from test.markers import stable_diffusion_skip_marker
from test.util import get_evadb_for_testing

import numpy as np
import pytest

from evadb.server.command_handler import execute_query_fetch_all


class StableDiffusionTest(unittest.TestCase):
def setUp(self) -> None:
self.evadb = get_evadb_for_testing()
self.evadb.catalog().reset()
create_table_query = """CREATE TABLE IF NOT EXISTS ImageGen (
prompt TEXT);
"""
execute_query_fetch_all(self.evadb, create_table_query)

test_prompts = ["pink cat riding a rocket to the moon"]

for prompt in test_prompts:
insert_query = f"""INSERT INTO ImageGen (prompt) VALUES ('{prompt}')"""
execute_query_fetch_all(self.evadb, insert_query)

def tearDown(self) -> None:
execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS ImageGen;")

@stable_diffusion_skip_marker
@pytest.mark.xfail(
reason="API call might be flaky due to rate limits or other issues."
)
def test_stable_diffusion_image_generation(self):
xzdandy marked this conversation as resolved.
Show resolved Hide resolved
function_name = "StableDiffusion"

execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")

create_function_query = f"""CREATE FUNCTION IF NOT EXISTS {function_name}
IMPL 'evadb/functions/stable_diffusion.py';
"""
execute_query_fetch_all(self.evadb, create_function_query)

gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;"
output_batch = execute_query_fetch_all(self.evadb, gpt_query)

self.assertEqual(output_batch.columns, ["stablediffusion.response"])

# Check if the returned data is an np.array representing an image
img_data = output_batch.frames["stablediffusion.response"][0]
self.assertIsInstance(img_data, np.ndarray)
self.assertEqual(
img_data.shape[2], 3
) # Check if the image has 3 channels (RGB)
5 changes: 5 additions & 0 deletions test/markers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
is_ludwig_available,
is_pinecone_available,
is_qdrant_available,
is_replicate_available,
is_sklearn_available,
)

Expand Down Expand Up @@ -96,3 +97,7 @@
is_forecast_available() is False,
reason="Run only if forecasting packages available",
)

stable_diffusion_skip_marker = pytest.mark.skipif(
is_replicate_available() is False, reason="requires replicate"
)
Loading