ddoc committed
Commit 739a2ab · 1 Parent(s): 0bf2c3b

Upload 10 files

Files changed (10)
  1. .gitignore +2 -0
  2. LICENSE.txt +21 -0
  3. README.md +4 -0
  4. explanation.html +49 -0
  5. install.py +13 -0
  6. javascript/promptgen.js +22 -0
  7. requirements.txt +2 -0
  8. screenshot.png +0 -0
  9. scripts/promptgen.py +282 -0
  10. style.css +54 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ /models
LICENSE.txt ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 AUTOMATIC1111
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,4 @@
+ # Prompt generator
+ An extension for [webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) that lets you generate prompts.
+
+ ![](screenshot.png)
explanation.html ADDED
@@ -0,0 +1,49 @@
+ <div id="promptgen_explanation_show">
+     <a style="font-weight: bold; cursor: pointer" onclick="gradioApp().getElementById('promptgen_explanation').style.display=''; gradioApp().getElementById('promptgen_explanation_show').style.display='none'; return false">
+         Information
+     </a>
+ </div>
+ <div style="display:none" id="promptgen_explanation">
+     <table>
+         <thead>
+             <tr>
+                 <th>Name</th>
+                 <th>Description</th>
+             </tr>
+         </thead>
+         <tbody>
+             <tr>
+                 <td>Top K</td>
+                 <td>When appending a word to the prompt, pick from the K most likely candidates.</td>
+             </tr>
+             <tr>
+                 <td>Top P</td>
+                 <td>When appending a word to the prompt, pick from the most likely candidates whose total probability is greater than P.</td>
+             </tr>
+             <tr>
+                 <td>Number of beams</td>
+                 <td>Track multiple copies of each prompt as it is being generated, and when done, pick the one with the highest likelihood.</td>
+             </tr>
+             <tr>
+                 <td>Temperature</td>
+                 <td>When appending a word to the prompt, the greater the temperature, the higher the chance of picking an unlikely candidate. At 0, all generated prompts are the same.</td>
+             </tr>
+             <tr>
+                 <td>Repetition penalty</td>
+                 <td>The greater the value, the less likely repeated terms are to appear in the prompt.</td>
+             </tr>
+             <tr>
+                 <td>Length preference</td>
+                 <td>Negative values tend to produce shorter prompts, positive values longer ones. Only takes effect with Number of beams > 1.</td>
+             </tr>
+             <tr>
+                 <td>Min length</td>
+                 <td>Minimum length of the generated prompt, in tokens.</td>
+             </tr>
+             <tr>
+                 <td>Max length</td>
+                 <td>Maximum length of the generated prompt, in tokens.</td>
+             </tr>
+         </tbody>
+     </table>
+ </div>
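
These settings map one-to-one onto keyword arguments of the transformers generate() API, which is how scripts/promptgen.py (below) consumes them. A minimal sketch of the mapping, using the extension's default slider values; "gpt2" is only a stand-in model so the snippet runs anywhere, not the model this extension actually loads:

# Sketch: how the table's knobs become generate() keyword arguments.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("1girl, solo, ", return_tensors="pt").input_ids
outputs = model.generate(
    input_ids,
    do_sample=True,
    top_k=12,                # "Top K" mode; pass top_p=0.15 instead for "Top P"
    num_beams=1,             # "Number of beams"
    temperature=1.0,         # "Temperature"
    repetition_penalty=1.0,  # "Repetition penalty"
    length_penalty=1.0,      # "Length preference"; only matters when num_beams > 1
    min_length=20,           # "Min length", in tokens
    max_length=150,          # "Max length", in tokens
    pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))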
install.py ADDED
@@ -0,0 +1,13 @@
+ import launch
+ import os
+
+ current_dir = os.path.dirname(os.path.realpath(__file__))
+ req_file = os.path.join(current_dir, "requirements.txt")
+
+ with open(req_file) as file:
+     for lib in file:
+         lib = lib.strip()
+         if lib and not launch.is_installed(lib.split("==")[0]):  # check the bare package name; install the pinned spec
+             launch.run_pip(
+                 f"install {lib}",
+                 f"danbooru-tag-gen requirement: {lib}")
javascript/promptgen.js ADDED
@@ -0,0 +1,22 @@
+
+ function promptgen_send_to(where, text){
+     var textarea = gradioApp().querySelector('#promptgen_selected_text textarea')
+     textarea.value = text
+     updateInput(textarea)
+
+     gradioApp().querySelector('#promptgen_send_to_' + where).click()
+
+     where == 'txt2img' ? switch_to_txt2img() : switch_to_img2img()
+ }
+
+ function promptgen_send_to_txt2img(text){ promptgen_send_to('txt2img', text) }
+ function promptgen_send_to_img2img(text){ promptgen_send_to('img2img', text) }
+
+ function submit_promptgen(){
+     var id = randomId()
+     requestProgress(id, gradioApp().getElementById('promptgen_results_column'), null, function(){})
+
+     var res = create_submit_args(arguments)
+     res[0] = id
+     return res
+ }
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ transformers==4.30.1
+ auto_gptq==0.2.2
screenshot.png ADDED
scripts/promptgen.py ADDED
@@ -0,0 +1,282 @@
+ import html
+ import os
+ import time
+
+ import torch
+ import transformers
+ from transformers import AutoTokenizer
+ from auto_gptq import AutoGPTQForCausalLM
+
+ from modules import shared, generation_parameters_copypaste
+
+ from modules import scripts, script_callbacks, devices, ui
+ import gradio as gr
+
+ from modules.ui_components import FormRow
+
+
+ class Model:
+     name = None
+     model = None
+     tokenizer = None
+
+
+ available_models = []
+ current = Model()
+
+ base_dir = scripts.basedir()
+ models_dir = os.path.join(base_dir, "models")
+
+
+ def device():
+     return devices.cpu if shared.opts.promptgen_device == 'cpu' else devices.device
+
+
+ def list_available_models():
+     available_models.clear()
+
+     os.makedirs(models_dir, exist_ok=True)
+
+     for dirname in os.listdir(models_dir):
+         if os.path.isdir(os.path.join(models_dir, dirname)):
+             available_models.append(dirname)
+
+     for name in [x.strip() for x in shared.opts.promptgen_names.split(",")]:
+         if not name:
+             continue
+
+         available_models.append(name)
+
+
+ def get_model_path(name):
+     dirname = os.path.join(models_dir, name)
+     if not os.path.isdir(dirname):
+         return name
+
+     return dirname
+
+
+ def generate_batch(input_ids, min_length, max_length, num_beams, temperature, repetition_penalty, length_penalty, sampling_mode, top_k, top_p):
+     top_p = float(top_p) if sampling_mode == 'Top P' else None
+     top_k = int(top_k) if sampling_mode == 'Top K' else None
+
+     outputs = current.model.generate(
+         input_ids,
+         do_sample=True,
+         temperature=max(float(temperature), 1e-6),
+         repetition_penalty=repetition_penalty,
+         length_penalty=length_penalty,
+         top_p=top_p,
+         top_k=top_k,
+         num_beams=int(num_beams),
+         min_length=min_length,
+         max_length=max_length,
+         pad_token_id=current.tokenizer.pad_token_id or current.tokenizer.eos_token_id
+     )
+     texts = current.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+     return texts
+
+
+ def model_selection_changed(model_name):
+     if model_name == "None":
+         current.tokenizer = None
+         current.model = None
+         current.name = None
+
+         devices.torch_gc()
+
+
+ def generate(id_task, model_name, batch_count, batch_size, text, *args):
+     shared.state.textinfo = "Loading model..."
+     shared.state.job_count = batch_count
+     model_name = 'qwopqwop/danbooru-llama-gptq'  # always use this model, overriding the dropdown selection
+
+     if current.name != model_name:
+         current.tokenizer = None
+         current.model = None
+         current.name = None
+
+         if model_name != 'None':
+             model = AutoGPTQForCausalLM.from_quantized("qwopqwop/danbooru-llama-gptq").model
+             current.model = model
+
+             DEFAULT_PAD_TOKEN = "[PAD]"
+
+             tokenizer = AutoTokenizer.from_pretrained("pinkmanlove/llama-7b-hf", use_fast=False)
+
+             def smart_tokenizer_and_embedding_resize(
+                 special_tokens_dict,
+                 tokenizer,
+                 model,
+             ):
+                 """Resize tokenizer and embedding.
+
+                 Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+                 """
+                 num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+                 model.resize_token_embeddings(len(tokenizer))
+
+                 if num_new_tokens > 0:
+                     input_embeddings = model.get_input_embeddings().weight.data
+                     output_embeddings = model.get_output_embeddings().weight.data
+
+                     input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+                     output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+                     input_embeddings[-num_new_tokens:] = input_embeddings_avg
+                     output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+             if tokenizer._pad_token is None:
+                 smart_tokenizer_and_embedding_resize(
+                     special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
+                     tokenizer=tokenizer,
+                     model=model)
+
+             tokenizer.add_special_tokens({"eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
+                                           "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
+                                           "unk_token": tokenizer.convert_ids_to_tokens(model.config.pad_token_id if model.config.pad_token_id != -1 else tokenizer.pad_token_id)})
+
+             current.tokenizer = tokenizer
+             current.name = model_name
+
+     assert current.model, 'No model available'
+     assert current.tokenizer, 'No tokenizer available'
+
+     current.model.to(device())
+
+     shared.state.textinfo = ""
+
+     input_ids = current.tokenizer(text, return_tensors="pt").input_ids
+     if input_ids.shape[1] == 0:
+         input_ids = torch.asarray([[current.tokenizer.bos_token_id]], dtype=torch.long)
+     input_ids = input_ids.to(device())
+     input_ids = input_ids.repeat((batch_size, 1))
+
+     markup = '<table><tbody>'
+
+     index = 0
+     for i in range(batch_count):
+         texts = generate_batch(input_ids, *args)
+         shared.state.nextjob()
+         for generated_text in texts:
+             index += 1
+             markup += f"""
+ <tr>
+ <td>
+ <div class="prompt gr-box gr-text-input">
+     <p id='promptgen_res_{index}'>{html.escape(generated_text)}</p>
+ </div>
+ </td>
+ <td class="sendto">
+     <a class='gr-button gr-button-lg gr-button-secondary' onclick="promptgen_send_to_txt2img(gradioApp().getElementById('promptgen_res_{index}').textContent)">to txt2img</a>
+     <a class='gr-button gr-button-lg gr-button-secondary' onclick="promptgen_send_to_img2img(gradioApp().getElementById('promptgen_res_{index}').textContent)">to img2img</a>
+ </td>
+ </tr>
+ """
+
+     markup += '</tbody></table>'
+
+     return markup, ''
+
+
+ def find_prompts(fields):
+     field_prompt = [x for x in fields if x[1] == "Prompt"][0]
+     field_negative_prompt = [x for x in fields if x[1] == "Negative prompt"][0]
+     return [field_prompt[0], field_negative_prompt[0]]
+
+
+ def send_prompts(text):
+     params = generation_parameters_copypaste.parse_generation_parameters(text)
+     negative_prompt = params.get("Negative prompt", "")
+     return params.get("Prompt", ""), negative_prompt or gr.update()
+
+
+ def add_tab():
+     list_available_models()
+
+     with gr.Blocks(analytics_enabled=False) as tab:
+         with gr.Row():
+             with gr.Column(scale=80):
+                 prompt = gr.Textbox(label="Prompt", elem_id="promptgen_prompt", show_label=False, lines=2, placeholder="Beginning of the prompt (press Ctrl+Enter or Alt+Enter to generate)").style(container=False)
+             with gr.Column(scale=10):
+                 submit = gr.Button('Generate', elem_id="promptgen_generate", variant='primary')
+
+         with gr.Row(elem_id="promptgen_main"):
+             with gr.Column(variant="compact"):
+                 selected_text = gr.TextArea(elem_id='promptgen_selected_text', visible=False)
+                 send_to_txt2img = gr.Button(elem_id='promptgen_send_to_txt2img', visible=False)
+                 send_to_img2img = gr.Button(elem_id='promptgen_send_to_img2img', visible=False)
+
+                 with FormRow():
+                     model_selection = gr.Dropdown(label="Model", elem_id="promptgen_model", value=available_models[0], choices=["None"] + available_models)
+
+                 with FormRow():
+                     sampling_mode = gr.Radio(label="Sampling mode", elem_id="promptgen_sampling_mode", value="Top K", choices=["Top K", "Top P"])
+                     top_k = gr.Slider(label="Top K", elem_id="promptgen_top_k", value=12, minimum=1, maximum=50, step=1)
+                     top_p = gr.Slider(label="Top P", elem_id="promptgen_top_p", value=0.15, minimum=0, maximum=1, step=0.001)
+
+                 with gr.Row():
+                     num_beams = gr.Slider(label="Number of beams", elem_id="promptgen_num_beams", value=1, minimum=1, maximum=8, step=1)
+                     temperature = gr.Slider(label="Temperature", elem_id="promptgen_temperature", value=1, minimum=0, maximum=4, step=0.01)
+                     repetition_penalty = gr.Slider(label="Repetition penalty", elem_id="promptgen_repetition_penalty", value=1, minimum=1, maximum=4, step=0.01)
+
+                 with FormRow():
+                     length_penalty = gr.Slider(label="Length preference", elem_id="promptgen_length_preference", value=1, minimum=-10, maximum=10, step=0.1)
+                     min_length = gr.Slider(label="Min length", elem_id="promptgen_min_length", value=20, minimum=1, maximum=400, step=1)
+                     max_length = gr.Slider(label="Max length", elem_id="promptgen_max_length", value=150, minimum=1, maximum=400, step=1)
+
+                 with FormRow():
+                     batch_count = gr.Slider(label="Batch count", elem_id="promptgen_batch_count", value=1, minimum=1, maximum=100, step=1)
+                     batch_size = gr.Slider(label="Batch size", elem_id="promptgen_batch_size", value=10, minimum=1, maximum=100, step=1)
+
+                 with open(os.path.join(base_dir, "explanation.html"), encoding="utf8") as file:
+                     footer = file.read()
+                     gr.HTML(footer)
+
+             with gr.Column():
+                 with gr.Group(elem_id="promptgen_results_column"):
+                     res = gr.HTML()
+                     res_info = gr.HTML()
+
+         submit.click(
+             fn=ui.wrap_gradio_gpu_call(generate, extra_outputs=['']),
+             _js="submit_promptgen",
+             inputs=[model_selection, model_selection, batch_count, batch_size, prompt, min_length, max_length, num_beams, temperature, repetition_penalty, length_penalty, sampling_mode, top_k, top_p, ],  # first entry is a placeholder; submit_promptgen() replaces it with the task id
+             outputs=[res, res_info]
+         )
+
+         model_selection.change(
+             fn=model_selection_changed,
+             inputs=[model_selection],
+             outputs=[],
+         )
+
+         send_to_txt2img.click(
+             fn=send_prompts,
+             inputs=[selected_text],
+             outputs=find_prompts(ui.txt2img_paste_fields)
+         )
+
+         send_to_img2img.click(
+             fn=send_prompts,
+             inputs=[selected_text],
+             outputs=find_prompts(ui.img2img_paste_fields)
+         )
+
+     return [(tab, "Promptgen", "promptgen")]
+
+
+ def on_ui_settings():
+     section = ("promptgen", "Promptgen")
+
+     shared.opts.add_option("promptgen_names", shared.OptionInfo("qwopqwop/danbooru-llama-gptq", "Hugging Face model names for promptgen, separated by comma", section=section))
+     shared.opts.add_option("promptgen_device", shared.OptionInfo("gpu", "Device to use for text generation", gr.Radio, {"choices": ["gpu"]}, section=section))
+
+ def on_unload():
+     current.model = None
+     current.tokenizer = None
+
+
+ script_callbacks.on_ui_tabs(add_tab)
+ script_callbacks.on_ui_settings(on_ui_settings)
+ script_callbacks.on_script_unloaded(on_unload)
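
To try the model outside the webui, here is a minimal standalone sketch of the same load-and-generate path as generate() above. The repository names and pinned libraries come from this commit; the device argument, prompt text, and sampling settings are illustrative assumptions:

# Standalone sketch of the extension's load/generate path (assumes a CUDA GPU
# and the pinned transformers==4.30.1 / auto_gptq==0.2.2 from requirements.txt).
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized("qwopqwop/danbooru-llama-gptq", device="cuda:0").model
tokenizer = AutoTokenizer.from_pretrained("pinkmanlove/llama-7b-hf", use_fast=False)

input_ids = tokenizer("1girl, ", return_tensors="pt").input_ids.to(model.device)
outputs = model.generate(
    input_ids,
    do_sample=True,
    top_k=12,  # the UI's default "Top K"
    min_length=20,
    max_length=150,
    pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))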
style.css ADDED
@@ -0,0 +1,54 @@
+
+ #promptgen_generate{
+     height: 100%
+ }
+
+ #promptgen_main{
+     margin-top: 1em;
+ }
+
+ #tab_promptgen table tr{
+     height: 1px;
+ }
+
+ #tab_promptgen table tr td{
+     height: 100%;
+     padding: 0.3em;
+ }
+
+ #tab_promptgen .prompt{
+     border: 1px solid rgba(128, 128, 128, 0.2);
+     height: 100%;
+ }
+
+ #tab_promptgen .prompt p{
+     white-space: pre-line;
+ }
+
+ #tab_promptgen .sendto{
+     width: 8em;
+ }
+
+ #tab_promptgen .sendto a{
+     cursor: pointer;
+     display: block;
+     margin: 0.2em;
+     padding: 0.4em;
+ }
+
+ #tab_promptgen .gr-form{
+     border: none;
+     padding-bottom: 0.5em;
+ }
+
+ #promptgen_explanation table{
+     border-collapse: collapse;
+ }
+
+ #promptgen_explanation table td, #promptgen_explanation table th{
+     border: 1px solid rgba(128,128,128,0.1);
+     vertical-align: top;
+ }