-
Notifications
You must be signed in to change notification settings - Fork 18
/
config.py
57 lines (52 loc) · 2.33 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional
@dataclass
class RunConfig:
    """Settings for a single generation run.

    Only ``prompt`` is required; every other field has a default. Creating an
    instance also creates the ``output_path`` directory (with parents) as a
    side effect of ``__post_init__``.
    """

    # Guiding text prompt
    prompt: str
    # Whether to use Stable Diffusion v2.1
    sd_2_1: bool = False
    # Which token indices to alter with attend-and-excite (None when not provided)
    token_indices: Optional[List[int]] = None
    # Which random seeds to use when generating
    seeds: List[int] = field(default_factory=lambda: [42])
    # Path to save all outputs to
    output_path: Path = Path('./outputs')
    # Number of denoising steps
    n_inference_steps: int = 50
    # Text guidance scale
    guidance_scale: float = 7.5
    # Number of denoising steps to apply attend-and-excite
    max_iter_to_alter: int = 25
    # Resolution of UNet to compute attention maps over
    attention_res: int = 16
    # Whether to run standard SD or attend-and-excite
    run_standard_sd: bool = False
    # Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in
    thresholds: Dict[int, float] = field(default_factory=lambda: {0: 0.05, 10: 0.5, 20: 0.8})
    # Scale factor for updating the denoised latent z_t
    scale_factor: int = 20
    # Start and end values used for scaling the scale factor - decays linearly with the denoising timestep
    scale_range: tuple = field(default_factory=lambda: (1.0, 0.5))
    # Whether to apply the Gaussian smoothing before computing the maximum attention value for each subject token
    smooth_attentions: bool = True
    # Standard deviation for the Gaussian smoothing
    sigma: float = 0.5
    # Kernel size for the Gaussian smoothing
    kernel_size: int = 3
    # Whether to save cross attention maps for the final results
    save_cross_attention_maps: bool = False

    # ---- BoxDiff options ----
    # NOTE(review): the fields below are not documented in this file; the hedged
    # notes are guesses from the names and should be confirmed against the callers.
    # Bounding boxes; defaults to two empty boxes (presumably one per subject token - TODO confirm)
    bbox: List[list] = field(default_factory=lambda: [[], []])
    # Color names (presumably for drawing the boxes - TODO confirm)
    color: List[str] = field(default_factory=lambda: ['blue', 'red', 'purple', 'orange', 'green', 'yellow', 'black'])
    # NOTE(review): meaning of P is not visible here - confirm against the code that reads it
    P: float = 0.2
    # number of pixels around the corner to be selected
    L: int = 1
    # NOTE(review): refinement toggle; what is refined is not visible here
    refine: bool = True
    # Phrases passed to GLIGEN (presumably one per box - TODO confirm)
    gligen_phrases: List[str] = field(default_factory=lambda: ['', ''])
    # NOTE(review): n_splits / which_one look like "split the work into n parts and run part k" - confirm
    n_splits: int = 4
    which_one: int = 1
    # Path for evaluation outputs (not created automatically by this class)
    eval_output_path: Path = Path('./outputs/eval')

    def __post_init__(self):
        """Normalize the path fields and make sure ``output_path`` exists."""
        # Accept plain strings / os.PathLike values for the path fields so a
        # hand-built config does not fail with AttributeError on ``.mkdir``.
        self.output_path = Path(self.output_path)
        self.eval_output_path = Path(self.eval_output_path)
        self.output_path.mkdir(exist_ok=True, parents=True)