From 48bcb845c8dc7d48373f47a15ae09fc3e8094aaa Mon Sep 17 00:00:00 2001 From: Pramod Chunduri Date: Fri, 17 Nov 2023 13:10:50 -0500 Subject: [PATCH 1/4] Migrate ChatGPT function to openai v1.0 --- evadb/functions/chatgpt.py | 22 +++++++++++-------- setup.py | 2 +- .../long/functions/test_chatgpt.py | 17 ++++++++------ 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/evadb/functions/chatgpt.py b/evadb/functions/chatgpt.py index fadc61191a..e5d5e31c61 100644 --- a/evadb/functions/chatgpt.py +++ b/evadb/functions/chatgpt.py @@ -115,19 +115,23 @@ def setup( ) def forward(self, text_df): try_to_import_openai() - import openai + from openai import OpenAI - @retry(tries=6, delay=20) - def completion_with_backoff(**kwargs): - return openai.ChatCompletion.create(**kwargs) - - openai.api_key = self.openai_api_key - if len(openai.api_key) == 0: - openai.api_key = os.environ.get("OPENAI_API_KEY", "") + api_key = self.openai_api_key + if len(self.openai_api_key) == 0: + api_key = os.environ.get("OPENAI_API_KEY", "") assert ( - len(openai.api_key) != 0 + len(api_key) != 0 ), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)" + client = OpenAI( + api_key=api_key + ) + + @retry(tries=6, delay=20) + def completion_with_backoff(**kwargs): + return client.chat.completions.create(**kwargs) + queries = text_df[text_df.columns[0]] content = text_df[text_df.columns[0]] if len(text_df.columns) > 1: diff --git a/setup.py b/setup.py index 3334fa8361..61dc0b8c66 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,7 @@ def read(path, encoding="utf-8"): "sentence-transformers", "protobuf", "bs4", - "openai==0.28", # CHATGPT + "openai>=1.0", # CHATGPT "gpt4all", # PRIVATE GPT "sentencepiece", # TRANSFORMERS ] diff --git a/test/integration_tests/long/functions/test_chatgpt.py b/test/integration_tests/long/functions/test_chatgpt.py index b72612d050..8af9bed2e5 100644 --- a/test/integration_tests/long/functions/test_chatgpt.py +++ b/test/integration_tests/long/functions/test_chatgpt.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest from test.markers import chatgpt_skip_marker from test.util import get_evadb_for_testing @@ -22,9 +23,8 @@ from evadb.server.command_handler import execute_query_fetch_all -def create_dummy_csv_file(config) -> str: - tmp_dir_from_config = config.get_value("storage", "tmp_dir") - +def create_dummy_csv_file(catalog) -> str: + tmp_dir_from_config = catalog.get_configuration_catalog_value("tmp_dir") df_dict = [ { "prompt": "summarize", @@ -49,17 +49,20 @@ def setUp(self) -> None: );""" execute_query_fetch_all(self.evadb, create_table_query) - self.csv_file_path = create_dummy_csv_file(self.evadb.config) + self.csv_file_path = create_dummy_csv_file(self.evadb.catalog()) csv_query = f"""LOAD CSV '{self.csv_file_path}' INTO MyTextCSV;""" execute_query_fetch_all(self.evadb, csv_query) + os.environ[ + "OPENAI_API_KEY" + ] = "sk-..." def tearDown(self) -> None: execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS MyTextCSV;") - @chatgpt_skip_marker + # @chatgpt_skip_marker def test_openai_chat_completion_function(self): - function_name = "OpenAIChatCompletion" + function_name = "ChatGPT" execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};") create_function_query = f"""CREATE FUNCTION IF NOT EXISTS{function_name} @@ -69,4 +72,4 @@ def test_openai_chat_completion_function(self): gpt_query = f"SELECT {function_name}('summarize', content) FROM MyTextCSV;" output_batch = execute_query_fetch_all(self.evadb, gpt_query) - self.assertEqual(output_batch.columns, ["openaichatcompletion.response"]) + self.assertEqual(output_batch.columns, ["chatgpt.response"]) From bf2e2ad36eacbaacf00b6194b18512b089cf3826 Mon Sep 17 00:00:00 2001 From: Pramod Chunduri <43007047+pchunduri6@users.noreply.github.com> Date: Fri, 17 Nov 2023 13:17:34 -0500 Subject: [PATCH 2/4] skip test --- test/integration_tests/long/functions/test_chatgpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration_tests/long/functions/test_chatgpt.py b/test/integration_tests/long/functions/test_chatgpt.py index 8af9bed2e5..f7cc874304 100644 --- a/test/integration_tests/long/functions/test_chatgpt.py +++ b/test/integration_tests/long/functions/test_chatgpt.py @@ -60,7 +60,7 @@ def setUp(self) -> None: def tearDown(self) -> None: execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS MyTextCSV;") - # @chatgpt_skip_marker + @chatgpt_skip_marker def test_openai_chat_completion_function(self): function_name = "ChatGPT" execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};") From 21ec26464715b5effa29875184a0e97a3188ce40 Mon Sep 17 00:00:00 2001 From: Pramod Chunduri Date: Fri, 17 Nov 2023 15:46:22 -0500 Subject: [PATCH 3/4] upgrade dalle function --- evadb/functions/dalle.py | 19 +++++++++--------- test/unit_tests/test_dalle.py | 36 ++++++++++++++++++++++++++++++----- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/evadb/functions/dalle.py b/evadb/functions/dalle.py index 7c1dc39dd0..03c2e77f88 100644 --- a/evadb/functions/dalle.py +++ b/evadb/functions/dalle.py @@ -56,24 +56,25 @@ def setup(self, openai_api_key="") -> None: ) def forward(self, text_df): try_to_import_openai() - import openai + from openai import OpenAI - openai.api_key = self.openai_api_key - # If not found, try OS Environment Variable - if len(openai.api_key) == 0: - openai.api_key = os.environ.get("OPENAI_API_KEY", "") + api_key = self.openai_api_key + if len(self.openai_api_key) == 0: + api_key = os.environ.get("OPENAI_API_KEY", "") assert ( - len(openai.api_key) != 0 - ), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)" + len(api_key) != 0 + ), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)" + + client = OpenAI(api_key=api_key) def generate_image(text_df: PandasDataframe): results = [] queries = text_df[text_df.columns[0]] for query in queries: - response = openai.Image.create(prompt=query, n=1, size="1024x1024") + response = client.images.generate(prompt=query, n=1, size="1024x1024") # Download the image from the link - image_response = requests.get(response["data"][0]["url"]) + image_response = requests.get(response.data[0].url) image = Image.open(BytesIO(image_response.content)) # Convert the image to an array format suitable for the DataFrame diff --git a/test/unit_tests/test_dalle.py b/test/unit_tests/test_dalle.py index c434a4db4a..4e877c15fc 100644 --- a/test/unit_tests/test_dalle.py +++ b/test/unit_tests/test_dalle.py @@ -18,10 +18,24 @@ from test.util import get_evadb_for_testing from unittest.mock import MagicMock, patch -from PIL import Image +from PIL import Image as PILImage from evadb.server.command_handler import execute_query_fetch_all +from typing import List, Optional +from pydantic import BaseModel, AnyUrl + + +class Image(BaseModel): + b64_json: Optional[str] # Replace with the actual type if different + revised_prompt: Optional[str] # Replace with the actual type if different + url: AnyUrl + + +class ImagesResponse(BaseModel): + created: Optional[int] # Replace with the actual type if different + data: List[Image] + class DallEFunctionTest(unittest.TestCase): def setUp(self) -> None: @@ -43,10 +57,10 @@ def tearDown(self) -> None: @patch.dict("os.environ", {"OPENAI_API_KEY": "mocked_openai_key"}) @patch("requests.get") - @patch("openai.Image.create", return_value={"data": [{"url": "mocked_url"}]}) - def test_dalle_image_generation(self, mock_openai_create, mock_requests_get): + @patch("openai.OpenAI") + def test_dalle_image_generation(self, mock_openai, mock_requests_get): # Generate a 1x1 white pixel PNG image in memory - img = Image.new("RGB", (1, 1), color="white") + img = PILImage.new("RGB", (1, 1), color="white") img_byte_array = BytesIO() img.save(img_byte_array, format="PNG") mock_image_content = img_byte_array.getvalue() @@ -55,6 +69,18 @@ def test_dalle_image_generation(self, mock_openai_create, mock_requests_get): mock_response.content = mock_image_content mock_requests_get.return_value = mock_response + # Set up the mock for OpenAI instance + mock_openai_instance = mock_openai.return_value + mock_openai_instance.images.generate.return_value = ImagesResponse( + data=[ + Image( + b64_json=None, + revised_prompt=None, + url="https://images.openai.com/1234.png" + ) + ] + ) + function_name = "DallE" execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};") @@ -67,6 +93,6 @@ def test_dalle_image_generation(self, mock_openai_create, mock_requests_get): gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;" execute_query_fetch_all(self.evadb, gpt_query) - mock_openai_create.assert_called_once_with( + mock_openai_instance.images.generate.assert_called_once_with( prompt="a surreal painting of a cat", n=1, size="1024x1024" ) From 2ce9275e432c5ce05c01c914da571295f128a182 Mon Sep 17 00:00:00 2001 From: Pramod Chunduri Date: Fri, 17 Nov 2023 16:17:48 -0500 Subject: [PATCH 4/4] fix linting --- evadb/functions/chatgpt.py | 4 +--- test/integration_tests/long/functions/test_chatgpt.py | 4 +--- test/unit_tests/test_dalle.py | 7 +++---- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/evadb/functions/chatgpt.py b/evadb/functions/chatgpt.py index e5d5e31c61..bf0d338689 100644 --- a/evadb/functions/chatgpt.py +++ b/evadb/functions/chatgpt.py @@ -124,9 +124,7 @@ def forward(self, text_df): len(api_key) != 0 ), "Please set your OpenAI API key using SET OPENAI_API_KEY = 'sk-' or environment variable (OPENAI_API_KEY)" - client = OpenAI( - api_key=api_key - ) + client = OpenAI(api_key=api_key) @retry(tries=6, delay=20) def completion_with_backoff(**kwargs): diff --git a/test/integration_tests/long/functions/test_chatgpt.py b/test/integration_tests/long/functions/test_chatgpt.py index f7cc874304..3f8cd9a92e 100644 --- a/test/integration_tests/long/functions/test_chatgpt.py +++ b/test/integration_tests/long/functions/test_chatgpt.py @@ -53,9 +53,7 @@ def setUp(self) -> None: csv_query = f"""LOAD CSV '{self.csv_file_path}' INTO MyTextCSV;""" execute_query_fetch_all(self.evadb, csv_query) - os.environ[ - "OPENAI_API_KEY" - ] = "sk-..." + os.environ["OPENAI_API_KEY"] = "sk-..." def tearDown(self) -> None: execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS MyTextCSV;") diff --git a/test/unit_tests/test_dalle.py b/test/unit_tests/test_dalle.py index 4e877c15fc..373e90e44c 100644 --- a/test/unit_tests/test_dalle.py +++ b/test/unit_tests/test_dalle.py @@ -16,15 +16,14 @@ import unittest from io import BytesIO from test.util import get_evadb_for_testing +from typing import List, Optional from unittest.mock import MagicMock, patch from PIL import Image as PILImage +from pydantic import AnyUrl, BaseModel from evadb.server.command_handler import execute_query_fetch_all -from typing import List, Optional -from pydantic import BaseModel, AnyUrl - class Image(BaseModel): b64_json: Optional[str] # Replace with the actual type if different @@ -76,7 +75,7 @@ def test_dalle_image_generation(self, mock_openai, mock_requests_get): Image( b64_json=None, revised_prompt=None, - url="https://images.openai.com/1234.png" + url="https://images.openai.com/1234.png", ) ] )