Skip to content

Commit

Permalink
Merge pull request #180 from geroldmeisinger/notify_when_finished
Browse files Browse the repository at this point in the history
Add option to show an alert when auto-captioning is finished
  • Loading branch information
jhc13 authored Jun 13, 2024
2 parents 105f6d8 + 7dbd1e9 commit eb1a816
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 18 deletions.
3 changes: 2 additions & 1 deletion taggui-linux.spec
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files

datas = [('clip-vit-base-patch32', 'clip-vit-base-patch32'), ('images/icon.ico', 'images')]
datas = [('clip-vit-base-patch32', 'clip-vit-base-patch32'),
('images/icon.ico', 'images')]
datas += [('/usr/include/python3.11', 'include/python3.11')]
datas += collect_data_files('triton')
datas += collect_data_files('xformers')
Expand Down
3 changes: 2 additions & 1 deletion taggui-windows.spec
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files

datas = [('clip-vit-base-patch32', 'clip-vit-base-patch32'), ('images/icon.ico', 'images')]
datas = [('clip-vit-base-patch32', 'clip-vit-base-patch32'),
('images/icon.ico', 'images')]
datas += collect_data_files('xformers')

block_cipher = None
Expand Down
13 changes: 12 additions & 1 deletion taggui/auto_captioning/captioning_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def __init__(self, parent, image_list_model: ImageListModel,
self.caption_settings = caption_settings
self.tag_separator = tag_separator
self.models_directory_path = models_directory_path
self.is_error = False
self.is_canceled = False

def load_processor_and_model(self, device: torch.device,
Expand Down Expand Up @@ -357,7 +358,7 @@ def get_caption_from_generated_tokens(
caption = caption.replace(self.tag_separator, ' ')
return caption

def run(self):
def run_captioning(self):
model_id = self.caption_settings['model']
model_type = get_model_type(model_id)
forced_words_string = self.caption_settings['forced_words']
Expand All @@ -366,6 +367,7 @@ def run(self):
beam_count = generation_parameters['num_beams']
if (forced_words_string.strip() and beam_count < 2
and model_type != CaptionModelType.WD_TAGGER):
self.is_error = True
self.clear_console_text_edit_requested.emit()
print('`Number of beams` must be greater than 1 when `Include in '
'caption` is not empty.')
Expand All @@ -392,6 +394,7 @@ def run(self):
error_message = get_xcomposer2_error_message(
model_id, self.caption_settings['device'], load_in_4_bit)
if error_message:
self.is_error = True
self.clear_console_text_edit_requested.emit()
print(error_message)
return
Expand Down Expand Up @@ -507,5 +510,13 @@ def run(self):
f'({average_captioning_duration:.1f} s/image) at '
f'{captioning_end_datetime.strftime("%Y-%m-%d %H:%M:%S")}.')

def run(self):
try:
self.run_captioning()
except Exception as exception:
self.is_error = True
# Show the error message in the console text edit.
raise exception

def write(self, text: str):
self.text_outputted.emit(text)
16 changes: 16 additions & 0 deletions taggui/dialogs/caption_multiple_images_dialog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from utils.settings_widgets import SettingsBigCheckBox
from utils.utils import ConfirmationDialog


class CaptionMultipleImagesDialog(ConfirmationDialog):
def __init__(self, selected_image_count: int):
title = 'Generate Captions'
question = f'Caption {selected_image_count} selected images?'
super().__init__(title=title, question=question)
self.show_alert_check_box = SettingsBigCheckBox(
key='show_alert_when_captioning_finished', default=True,
text='Show alert when finished')
self.setCheckBox(self.show_alert_check_box)
layout = self.layout()
layout.setContentsMargins(20, 20, 20, 20)
layout.setSpacing(20)
4 changes: 2 additions & 2 deletions taggui/utils/big_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def __init__(self, text: str):


class BigCheckBox(QCheckBox):
def __init__(self):
super().__init__()
def __init__(self, text: str | None = None):
super().__init__(text)
settings = get_settings()
font_size = settings.value(
'font_size', defaultValue=DEFAULT_SETTINGS['font_size'], type=int)
Expand Down
4 changes: 2 additions & 2 deletions taggui/utils/settings_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@


class SettingsBigCheckBox(BigCheckBox):
def __init__(self, key: str, default: bool):
super().__init__()
def __init__(self, key: str, default: bool, text: str | None = None):
super().__init__(text)
settings = get_settings()
self.setChecked(settings.value(key, default, type=bool))
self.stateChanged.connect(
Expand Down
19 changes: 12 additions & 7 deletions taggui/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,18 @@ def list_with_and(items: list[str]) -> str:
return ', '.join(items[:-1]) + f', and {items[-1]}'


class ConfirmationDialog(QMessageBox):
def __init__(self, title: str, question: str):
super().__init__()
self.setWindowTitle(title)
self.setIcon(QMessageBox.Icon.Question)
self.setText(question)
self.setStandardButtons(QMessageBox.StandardButton.Yes
| QMessageBox.StandardButton.Cancel)
self.setDefaultButton(QMessageBox.StandardButton.Yes)


def get_confirmation_dialog_reply(title: str, question: str) -> int:
"""Display a confirmation dialog and return the user's reply."""
confirmation_dialog = QMessageBox()
confirmation_dialog.setWindowTitle(title)
confirmation_dialog.setIcon(QMessageBox.Icon.Question)
confirmation_dialog.setText(question)
confirmation_dialog.setStandardButtons(QMessageBox.StandardButton.Yes
| QMessageBox.StandardButton.Cancel)
confirmation_dialog.setDefaultButton(QMessageBox.StandardButton.Yes)
confirmation_dialog = ConfirmationDialog(title, question)
return confirmation_dialog.exec()
30 changes: 26 additions & 4 deletions taggui/widgets/auto_captioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from auto_captioning.captioning_thread import CaptioningThread
from auto_captioning.models import MODELS, get_model_type
from dialogs.caption_multiple_images_dialog import CaptionMultipleImagesDialog
from models.image_list_model import ImageListModel
from utils.big_widgets import TallPushButton
from utils.enums import CaptionDevice, CaptionModelType, CaptionPosition
Expand All @@ -19,7 +20,7 @@
FocusedScrollSettingsSpinBox,
SettingsBigCheckBox, SettingsLineEdit,
SettingsPlainTextEdit)
from utils.utils import get_confirmation_dialog_reply, pluralize
from utils.utils import pluralize
from widgets.image_list import ImageList


Expand Down Expand Up @@ -439,16 +440,35 @@ def update_console_text_edit(self, text: str):
self.console_text_edit.textCursor().deletePreviousChar()
self.console_text_edit.appendPlainText(text)

@Slot()
def show_alert(self):
if self.captioning_thread.is_canceled:
return
if self.captioning_thread.is_error:
icon = QMessageBox.Icon.Critical
text = ('An error occurred during captioning. See the '
'Auto-Captioner console for more information.')
else:
icon = QMessageBox.Icon.Information
text = 'Captioning has finished.'
alert = QMessageBox()
alert.setIcon(icon)
alert.setText(text)
alert.exec()

@Slot()
def generate_captions(self):
selected_image_indices = self.image_list.get_selected_image_indices()
selected_image_count = len(selected_image_indices)
show_alert_when_finished = False
if selected_image_count > 1:
reply = get_confirmation_dialog_reply(
title='Generate Captions',
question=f'Caption {selected_image_count} selected images?')
confirmation_dialog = CaptionMultipleImagesDialog(
selected_image_count)
reply = confirmation_dialog.exec()
if reply != QMessageBox.StandardButton.Yes:
return
show_alert_when_finished = (confirmation_dialog
.show_alert_check_box.isChecked())
self.set_is_captioning(True)
caption_settings = self.caption_settings_form.get_caption_settings()
if caption_settings['caption_position'] != CaptionPosition.DO_NOT_ADD:
Expand Down Expand Up @@ -483,6 +503,8 @@ def generate_captions(self):
self.captioning_thread.finished.connect(self.progress_bar.hide)
self.captioning_thread.finished.connect(
lambda: self.start_cancel_button.setEnabled(True))
if show_alert_when_finished:
self.captioning_thread.finished.connect(self.show_alert)
# Redirect `stdout` and `stderr` so that the outputs are displayed in
# the console text edit.
sys.stdout = self.captioning_thread
Expand Down

0 comments on commit eb1a816

Please sign in to comment.