forked from Trojaner/text-generation-webui-stable_diffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
script.py
303 lines (230 loc) · 8.79 KB
/
script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
import html
import re
from dataclasses import asdict
from os import path
from typing import Any, List
from json_schema_logits_processor.json_schema_logits_processor import (
JsonSchemaLogitsProcessor,
)
from json_schema_logits_processor.schema.interative_schema import (
parse_schema_from_string,
)
from llama_cpp import LogitsProcessor
from transformers import PreTrainedTokenizer
from modules import chat, shared
from modules.logging_colors import logger
from .context import GenerationContext, get_current_context, set_current_context
from .ext_modules.image_generator import generate_html_images_for_context
from .ext_modules.text_analyzer import try_get_description_prompt
from .params import (
InteractiveModePromptGenerationMode,
StableDiffusionWebUiExtensionParams,
TriggerMode,
)
from .sd_client import SdWebUIApi
from .ui import render_ui
# Live parameter object bound to the UI widgets (see ui() below).
ui_params: Any = StableDiffusionWebUiExtensionParams()
# Plain-dict mirror of ui_params, kept in sync by get_or_create_context().
params = asdict(ui_params)
# The single in-flight generation context, or None when idle/completed.
context: GenerationContext | None = None
# Placeholder shown instead of streamed text while an image is generating.
picture_processing_message = "*Is sending a picture...*"
# Original message saved so cleanup_context() can restore it.
default_processing_message = shared.processing_message
# Cache of the last JSON schema text and its compiled logits processor,
# so the schema is only re-parsed when it changes.
cached_schema_text: str | None = None
cached_schema_logits: LogitsProcessor | None = None
EXTENSION_DIRECTORY_NAME = path.basename(path.dirname(path.realpath(__file__)))
def get_or_create_context(state: dict | None = None) -> GenerationContext:
    """
    Return the active generation context, creating a fresh one if needed.

    Syncs the latest UI values into the module-level ``params`` dict, builds
    a new SD WebUI client from them, and either reuses the current in-flight
    context (merging ``state`` into it) or creates and registers a new one.

    Args:
        state: Optional webui state dict to attach to the context.

    Returns:
        The (possibly newly created) GenerationContext.
    """
    global context, params, ui_params

    # Mirror the live UI parameter object into the plain params dict.
    params.update(ui_params.__dict__)

    sd_client = SdWebUIApi(
        baseurl=params["api_endpoint"],
        username=params["api_username"],
        password=params["api_password"],
    )

    if context is not None and not context.is_completed:
        # Reuse the in-flight context: merge the new state over the old one
        # and refresh the SD client (credentials/endpoint may have changed).
        context.state = (context.state or {}) | (state or {})
        context.sd_client = sd_client
        return context

    ext_params = StableDiffusionWebUiExtensionParams(**params)
    ext_params.normalize()

    # At this point context is guaranteed to be None or completed (the early
    # return above handles the only other case), so always create a new one.
    # (The original code had a dead conditional re-checking this.)
    context = GenerationContext(
        params=ext_params,
        sd_client=sd_client,
        input_text=None,
        state=state or {},
    )
    set_current_context(context)
    return context
def custom_generate_chat_prompt(text: str, state: dict, **kwargs: dict) -> str:
    """
    Modifies the user input string in chat mode (visible_text).
    You can also modify the internal representation of the user
    input (text) to change how it will appear in the prompt.
    """
    # bug: this does not trigger on regeneration and hence
    # no context is created in that case
    prompt: str = chat.generate_chat_prompt(text, state, **kwargs)  # type: ignore
    input_text = text
    # Ensure a context exists for this turn and attach the raw user input.
    context = get_or_create_context(state)
    context.input_text = input_text
    context.state = state
    # NOTE(review): get_or_create_context never returns None, and an active
    # (not completed) context makes this condition true — which would skip the
    # INTERACTIVE branch below on every call. Possibly the first clause was
    # meant to be `context is None or context.is_completed`; confirm against
    # GenerationContext's defaults and upstream history.
    if (
        context is not None and not context.is_completed
    ) or context.params.trigger_mode == TriggerMode.MANUAL:
        # A manual trigger was used
        return prompt
    if context.params.trigger_mode == TriggerMode.INTERACTIVE:
        # Check whether the user's message matches the image-request trigger.
        description_prompt = try_get_description_prompt(text, context.params)
        if description_prompt is False:
            # did not match image trigger
            return prompt
        assert isinstance(description_prompt, str)
        # DYNAMIC mode replaces the whole prompt with the description request;
        # otherwise the original user text is kept verbatim.
        prompt = (
            description_prompt
            if context.params.interactive_mode_prompt_generation_mode
            == InteractiveModePromptGenerationMode.DYNAMIC
            else text
        )
    return prompt
def state_modifier(state: dict) -> dict:
    """
    Modifies the state variable, which is a dictionary containing the input
    values in the UI like sliders and checkboxes.
    """
    ctx = get_or_create_context(state)
    if ctx is None or ctx.is_completed:
        return state

    # Disable token streaming while an image is on the way (always for TOOL
    # mode, otherwise only when the user opted out of streaming).
    no_stream_wanted = ctx.params.dont_stream_when_generating_images
    if ctx.params.trigger_mode == TriggerMode.TOOL or no_stream_wanted:
        state["stream"] = False
        if no_stream_wanted:
            shared.processing_message = picture_processing_message
        else:
            shared.processing_message = default_processing_message

    return state
def history_modifier(history: List[str]) -> List[str]:
    """
    Modifies the chat history.
    Only used in chat mode.
    """
    ctx = get_current_context()
    if ctx is not None and not ctx.is_completed:
        # todo: strip <img> tags from history
        pass
    return history
def cleanup_context() -> None:
    """
    Mark the current generation context as finished and reset global state.

    Flags the active context (if any) as completed, clears the context
    registry, and restores the webui's original processing message that
    state_modifier() may have replaced.
    """
    context = get_current_context()
    if context is not None:
        context.is_completed = True
    set_current_context(None)
    shared.processing_message = default_processing_message
def output_modifier(string: str, state: dict, is_chat: bool = False) -> str:
    """
    Modifies the LLM output before it gets presented.
    In chat mode, the modified version goes into history['visible'],
    and the original version goes into history['internal'].
    """
    global params
    # Image generation is only supported in chat mode; otherwise just tidy up.
    if not is_chat:
        cleanup_context()
        return string
    context = get_current_context()
    if context is None or context.is_completed:
        # No active context (e.g. regeneration — see the bug note in
        # custom_generate_chat_prompt). In INTERACTIVE mode, try to create
        # one late by matching the model's *output* against the trigger regex.
        ext_params = StableDiffusionWebUiExtensionParams(**params)
        ext_params.normalize()
        if ext_params.trigger_mode == TriggerMode.INTERACTIVE:
            output_regex = ext_params.interactive_mode_output_trigger_regex
            # Unescape HTML entities so the regex sees the raw model text.
            normalized_message = html.unescape(string).strip()
            if output_regex and re.match(
                output_regex, normalized_message, re.IGNORECASE
            ):
                sd_client = SdWebUIApi(
                    baseurl=ext_params.api_endpoint,
                    username=ext_params.api_username,
                    password=ext_params.api_password,
                )
                context = GenerationContext(
                    params=ext_params,
                    sd_client=sd_client,
                    input_text=state.get("input", ""),
                    state=state,
                )
                set_current_context(context)
    if context is None or context.is_completed:
        # Still no usable context — nothing triggered image generation.
        cleanup_context()
        return string
    context.state = state
    context.output_text = string
    if "<img " in string:
        # Output already contains an image (e.g. from a previous pass) — skip.
        cleanup_context()
        return string
    try:
        string, images_html, prompt, _, _, _ = generate_html_images_for_context(context)
        # Escape the (possibly rewritten) text before appending raw image HTML.
        string = html.escape(string)
        if images_html:
            string = f"{string}\n\n{images_html}"
        # Show the generated SD prompt when the mode implies it was derived
        # dynamically rather than typed by the user.
        if prompt and (
            context.params.trigger_mode == TriggerMode.TOOL
            or (
                context.params.trigger_mode == TriggerMode.INTERACTIVE
                and context.params.interactive_mode_prompt_generation_mode
                == InteractiveModePromptGenerationMode.DYNAMIC
            )
        ):
            string = f"{string}\n*{html.escape(prompt).strip()}*"
    except Exception as e:
        # Best-effort: surface the failure to the user and keep the chat alive.
        string += "\n\n*Image generation has failed. Check logs for errors.*"
        logger.error(e, exc_info=True)
    cleanup_context()
    return string
def logits_processor_modifier(processor_list: List[LogitsProcessor], input_ids):
    """
    Adds logits processors to the list, allowing you to access and modify
    the next token probabilities.
    Only used by loaders that use the transformers library for sampling.

    In TOOL mode with forced-JSON output enabled, appends a (cached)
    JsonSchemaLogitsProcessor built from the configured schema text.
    Returns the processor list unchanged in every other case, including
    when the schema fails to parse.
    """
    global cached_schema_text, cached_schema_logits
    context = get_current_context()
    if (
        context is None
        or context.is_completed
        or context.params.trigger_mode != TriggerMode.TOOL
        or not context.params.tool_mode_force_json_output_enabled
        or not isinstance(shared.tokenizer, PreTrainedTokenizer)
    ):
        return processor_list
    schema_text = context.params.tool_mode_force_json_output_schema or ""
    if not schema_text.strip():
        return processor_list
    # Rebuild the processor only when the schema text actually changed.
    if cached_schema_text != schema_text or cached_schema_logits is None:
        try:
            schema = parse_schema_from_string(schema_text)
        except Exception as e:
            logger.error(
                "Failed to parse JSON schema: %s,\nSchema: %s",
                repr(e),
                schema_text,
                exc_info=True,
            )
            # Bug fix: the original fell through here and used the unbound
            # local `schema`, raising NameError. On a broken schema, skip
            # JSON enforcement entirely instead of crashing generation.
            return processor_list
        cached_schema_logits = JsonSchemaLogitsProcessor(schema, shared.tokenizer)  # type: ignore
        cached_schema_text = schema_text
    assert cached_schema_logits is not None, "cached_schema_logits is None"
    processor_list.append(cached_schema_logits)
    return processor_list
def ui() -> None:
    """
    Gets executed when the UI is drawn. Custom gradio elements and
    their corresponding event handlers should be defined here.
    To learn about gradio components, check out the docs:
    https://gradio.app/docs/
    """
    global ui_params
    # Rebuild the parameter object from the current dict and hand it to the UI.
    fresh_params = StableDiffusionWebUiExtensionParams(**params)
    ui_params = fresh_params
    render_ui(fresh_params)