Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

9 image generation from text prompts #13

Merged
merged 11 commits into from
Sep 12, 2023
16 changes: 3 additions & 13 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:

strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v2
Expand All @@ -28,18 +28,8 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
pip install ruff
pip install -r requirements.txt
- name: Run tests with pytest
run: |
pytest tests
- name: Run lintering with ruff
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics

ruff src --fix # make ruff run for src folder files and --fix to fix automatically linting errors
ruff tests --fix
pytest -m "not apitest"

14 changes: 14 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
aiohttp==3.8.5
aiosignal==1.3.1
async-timeout==4.0.3
attrs==23.1.0
certifi==2023.7.22
charset-normalizer==3.2.0
colorama==0.4.6
frozenlist==1.4.0
idna==3.4
multidict==6.0.4
openai==0.28.0
requests==2.31.0
tqdm==4.66.1
urllib3==2.0.4
yarl==1.9.2
exceptiongroup==1.1.3
iniconfig==2.0.0
packaging==23.1
Expand Down
63 changes: 63 additions & 0 deletions src/image_generation/image_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# The class is a wrapper class for generating images from a prompt
# The class uses the OpenAI API to generate images

import openai

class ImageGenerator:
MAX_PROMPT_LENGTH = 1000
SUPPORTED_SIZES = [256, 512, 1024]

def get_user_prompt(self) -> str:
image_prompt = input('What do you want to generate an image of? ')
self._validate_prompt(image_prompt)
return image_prompt

def _validate_prompt(self, image_prompt: str) -> None:
if not isinstance(image_prompt, str):
raise TypeError('Prompt must be a string')

if len(image_prompt) == 0:
raise ValueError('Prompt must not be empty')

if (len(image_prompt) > self.MAX_PROMPT_LENGTH):
raise ValueError(f'Prompt must be less than {self.MAX_PROMPT_LENGTH} characters')

def _validate_size(self, width: int, height: int) -> None:
if not isinstance(width, int):
raise TypeError('Width must be an integer')

if not isinstance(height, int):
raise TypeError('Height must be an integer')

if (width != height):
raise ValueError('Width and height must be equal')

if (width not in self.SUPPORTED_SIZES):
raise ValueError(f'Width must be {self.SUPPORTED_SIZES}')

def generate_image(self, image_prompt: str, width: int, height: int) -> dict:
'''
Generates an image from a prompt of a certain size
Args:
image_prompt (str): The prompt to generate the image from. Must be less than 1000 characters.
width (int): The width of the image. Must be 256, 512, or 1024
height (int): The height of the image. Must be 256, 512, or 1024
Returns:
A dictionary containing the image url
'''
self._validate_prompt(image_prompt)
self._validate_size(width, height)
size = str(width) + 'x' + str(height)
openai.api_key = '' # TODO: Make sure this is the way we get environment variables
response = openai.Image.create(
prompt=image_prompt,
n=1,
size=size
)
return response

if __name__ == "__main__":
image_generator = ImageGenerator()
image_prompt = image_generator.get_user_prompt()
response = image_generator.generate_image(image_prompt, 512, 512)
print(response['data'][0]['url'])
150 changes: 150 additions & 0 deletions tests/image_generation/image_generator_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import pytest

from src.image_generation.image_generator import ImageGenerator

def test_image_generation_with_empty_string_prompt():
# Arrange
image_generator = ImageGenerator()
no_prompt = ""
with pytest.raises(ValueError):
# Act
result = image_generator.generate_image(no_prompt, 1, 1)

def test_image_generation_with_non_string_prompt():
# Arrange
image_generator = ImageGenerator()
non_string_prompt = 1
with pytest.raises(TypeError):
# Act
result = image_generator.generate_image(non_string_prompt, 1, 1)

def test_image_generation_with_none_prompt():
# Arrange
image_generator = ImageGenerator()
none_prompt = None
with pytest.raises(TypeError):
# Act
result = image_generator.generate_image(none_prompt, 1, 1)


def test_image_generation_with_too_long_prompt():
# Arrange
image_generator = ImageGenerator()
too_long_prompt = "A" * 3001
with pytest.raises(ValueError):
# Act
result = image_generator.generate_image(too_long_prompt, 1, 1)

def test_image_generation_with_non_integer_width():
# Arrange
image_generator = ImageGenerator()
non_integer_width = "1"
valid_height = 1
valid_prompt = "A picture of a cat"
with pytest.raises(TypeError):
# Act
result = image_generator.generate_image(valid_prompt, non_integer_width, valid_height)

def test_image_generation_with_non_integer_height():
# Arrange
image_generator = ImageGenerator()
valid_width = 1
non_integer_height = "1"
valid_prompt = "A picture of a cat"
with pytest.raises(TypeError):
# Act
result = image_generator.generate_image(valid_prompt, valid_width, non_integer_height)
# Assert

def test_image_generation_with_too_small_width():
# Arrange

image_generator = ImageGenerator()
too_small_width = 0
valid_height = 1
valid_prompt = "A picture of a cat"
with pytest.raises(ValueError):
# Act
result = image_generator.generate_image(valid_prompt, too_small_width, valid_height)
# Assert

def test_image_generation_with_too_small_height():
# Arrange

image_generator = ImageGenerator()
valid_width = 1
too_small_height = 0
valid_prompt = "A picture of a cat"
with pytest.raises(ValueError):
# Act
result = image_generator.generate_image(valid_prompt, valid_width, too_small_height)
# Assert

def test_image_generation_with_too_large_width():
# Arrange

image_generator = ImageGenerator()
too_large_width = 1025
valid_height = 1
valid_prompt = "A picture of a cat"
with pytest.raises(ValueError):
# Act
result = image_generator.generate_image(valid_prompt, too_large_width, valid_height)
# Assert

def test_image_generation_with_too_large_height():
# Arrange

image_generator = ImageGenerator()
valid_width = 1
too_large_height = 1025
valid_prompt = "A picture of a cat"
with pytest.raises(ValueError):
# Act
result = image_generator.generate_image(valid_prompt, valid_width, too_large_height)
# Assert

def test_image_generation_with_width_not_equal_to_height():
# Arrange

image_generator = ImageGenerator()
width_not_equal_to_height = 512
valid_prompt = "A picture of a cat"
with pytest.raises(ValueError):
# Act
result = image_generator.generate_image(valid_prompt, width_not_equal_to_height, width_not_equal_to_height + 1)
# Assert

def test_image_generation_with_invalid_width():
# Arrange

image_generator = ImageGenerator()
invalid_width = 300
valid_prompt = "A picture of a cat"
with pytest.raises(ValueError):
# Act
result = image_generator.generate_image(valid_prompt, invalid_width, invalid_width)
# Assert

def test_image_generation_with_invalid_height():
# Arrange

image_generator = ImageGenerator()
invalid_height = 300
valid_prompt = "A picture of a cat"
with pytest.raises(ValueError):
# Act
result = image_generator.generate_image(valid_prompt, invalid_height, invalid_height)
# Assert

@pytest.mark.apitest
def test_image_generation_with_valid_prompt():
# Arrange

image_generator = ImageGenerator()
valid_size = 256
valid_prompt = "A picture of a cat"
# Act
result = image_generator.generate_image(valid_prompt, valid_size, valid_size)
# Assert
assert result is not None
Loading