fix: add safe bank access handling in single and double stream blocks #52
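This PR wraps the rfedit "bank" reads and writes in DoubleStreamBlock and SingleStreamBlock with guarded access: missing nesting levels are created before a forward-pass write, a reverse-pass read happens only when every key exists, and any remaining failure is caught and logged instead of crashing sampling. It also replaces a hard-coded text length of 256 in ref_attention with the txt_shape argument, drops stray f-string prefixes, and removes dead commented-out code. A minimal, self-contained sketch of the guarded pattern follows; the helper names bank_store/bank_load and the bare `bank` dict are illustrative, not part of the PR:

    import torch

    def bank_store(bank, step, pred, layer_idx, v):
        # Forward pass: create any missing nesting level before writing.
        if step not in bank:
            bank[step] = {}
        if pred not in bank[step]:
            bank[step][pred] = {}
        bank[step][pred][layer_idx] = v.cpu()

    def bank_load(bank, step, pred, layer_idx, v):
        # Reverse pass: read the cached tensor only if every key exists,
        # otherwise fall back to the value we already have.
        if step in bank and pred in bank[step] and layer_idx in bank[step][pred]:
            return bank[step][pred][layer_idx].to(v.device)
        return v

    bank = {}
    v = torch.randn(2, 8)
    bank_store(bank, step=0, pred='cond', layer_idx=3, v=v)
    assert torch.equal(bank_load(bank, 0, 'cond', 3, torch.zeros(2, 8)), v)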

Open · wants to merge 1 commit into main
76 changes: 44 additions & 32 deletions flux/layers.py
@@ -39,10 +39,9 @@ def ref_attention(q, k, v, pe, ref_config, ref_type, idx, txt_shape=256):
     attn = torch.cat([attn_a, attn])
     max_val = 1.0
     attn2 = attention(q[:1], k[1:], v[1:], pe=pe[:1], skip_rope=False, k_pe=pe[:1])
-    img_attn1 = attn[:1, 256 :]
-    img_attn1 = attn[:1, txt_shape :]
-    img_attn2 = attn2[:1, txt_shape :]
-    strength = min(max_val, ref_config[f'strengths'][ref_config['step']])
+    img_attn1 = attn[:1, txt_shape:]
+    img_attn2 = attn2[:1, txt_shape:]
+    strength = min(max_val, ref_config['strengths'][ref_config['step']])
     attn[:1, txt_shape:] = img_attn1*(1-strength) + img_attn2*strength
     attn[1:, :256] = attn[:1, :256]
     return attn
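The replacement of the hard-coded 256 above matters because attn is laid out as [text tokens | image tokens] along dim 1, so the image slice has to start at the actual text length (txt_shape) rather than a fixed offset. A toy illustration with made-up shapes:

    import torch

    txt_shape, n_img = 512, 1024                  # hypothetical token counts
    attn = torch.randn(2, txt_shape + n_img, 64)  # [text | image] along dim 1
    img_attn = attn[:1, txt_shape:]               # image-token portion of item 0
    assert img_attn.shape == (1, n_img, 64)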
@@ -72,24 +71,34 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, ref_config,
         k = torch.cat((txt_k, img_k), dim=2)
         v = torch.cat((txt_v, img_v), dim=2)

-        post_q_fn = transformer_options.get('patches_replace', {}).get(f'double', {}).get(('post_q', self.idx), None)
+        post_q_fn = transformer_options.get('patches_replace', {}).get('double', {}).get(('post_q', self.idx), None)
         if post_q_fn is not None:
             q = post_q_fn(q, transformer_options)

         # Mask Patch
-        mask_fn = transformer_options.get('patches_replace', {}).get(f'double', {}).get(('mask_fn', self.idx), None)
+        mask_fn = transformer_options.get('patches_replace', {}).get('double', {}).get(('mask_fn', self.idx), None)
         mask = None
         if mask_fn is not None:
             mask = mask_fn(q, transformer_options, 256)

         rfedit = transformer_options.get('rfedit', {})
-        if rfedit.get('process', None) is not None and rfedit['double_layers'][str(self.idx)]:
-            pred = rfedit['pred']
-            step = rfedit['step']
-            if rfedit['process'] == 'forward':
-                rfedit['bank'][step][pred][self.idx] = v.cpu()
-            elif rfedit['process'] == 'reverse':
-                v = rfedit['bank'][step][pred][self.idx].to(v.device)
+        if rfedit.get('process', None) is not None and rfedit.get('double_layers', {}).get(str(self.idx), False):
+            pred = rfedit.get('pred')
+            step = rfedit.get('step')
+
+            # Safely handle the bank access
+            try:
+                if rfedit['process'] == 'forward':
+                    if step not in rfedit['bank']:
+                        rfedit['bank'][step] = {}
+                    if pred not in rfedit['bank'][step]:
+                        rfedit['bank'][step][pred] = {}
+                    rfedit['bank'][step][pred][self.idx] = v.cpu()
+                elif rfedit['process'] == 'reverse':
+                    if step in rfedit['bank'] and pred in rfedit['bank'][step] and self.idx in rfedit['bank'][step][pred]:
+                        v = rfedit['bank'][step][pred][self.idx].to(v.device)
+            except Exception as e:
+                print(f"Warning: Error in rfedit processing for DoubleStreamBlock {self.idx}: {str(e)}")

         rave_options = transformer_options.get('RAVE', None)
         if ref_config is not None and ref_config['strengths'][ref_config['step']] > 0 and self.idx <= 20:
@@ -99,17 +108,14 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, ref_config,
         else:
             attn = attention(q, k, v, pe=pe, mask=mask)

-        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
+        txt_attn, img_attn = attn[:, :txt.shape[1]], attn[:, txt.shape[1]:]
         txt_attn = txt_attn[0:1].repeat(img_attn.shape[0], 1, 1)

-        # if self.idx % 8 == 0:
-        # img_attn = flow_attention(img_attn, self.num_heads, self.hidden_size // self.num_heads, transformer_options)
-
-        # calculate the img bloks
+        # calculate the img blocks
         img = img + img_mod1.gate * self.img_attn.proj(img_attn)
         img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

-        # calculate the txt bloks
+        # calculate the txt blocks
         txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
         txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
         return img, txt
@@ -124,23 +130,33 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, ref_config, timestep, tran
         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
         q, k = self.norm(q, k, v)

-        post_q_fn = transformer_options.get('patches_replace', {}).get(f'single', {}).get(('post_q', self.idx), None)
+        post_q_fn = transformer_options.get('patches_replace', {}).get('single', {}).get(('post_q', self.idx), None)
         if post_q_fn is not None:
             q = post_q_fn(q, transformer_options)

-        mask_fn = transformer_options.get('patches_replace', {}).get(f'single', {}).get(('mask_fn', self.idx), None)
+        mask_fn = transformer_options.get('patches_replace', {}).get('single', {}).get(('mask_fn', self.idx), None)
         mask = None
         if mask_fn is not None:
             mask = mask_fn(q, transformer_options, 256)

         rfedit = transformer_options.get('rfedit', {})
-        if rfedit.get('process', None) is not None and rfedit['single_layers'][str(self.idx)]:
-            pred = rfedit['pred']
-            step = rfedit['step']
-            if rfedit['process'] == 'forward':
-                rfedit['bank'][step][pred][self.idx] = v.cpu()
-            elif rfedit['process'] == 'reverse':
-                v = rfedit['bank'][step][pred][self.idx].to(v.device)
+        if rfedit.get('process', None) is not None and rfedit.get('single_layers', {}).get(str(self.idx), False):
+            pred = rfedit.get('pred')
+            step = rfedit.get('step')
+
+            # Safely handle the bank access
+            try:
+                if rfedit['process'] == 'forward':
+                    if step not in rfedit['bank']:
+                        rfedit['bank'][step] = {}
+                    if pred not in rfedit['bank'][step]:
+                        rfedit['bank'][step][pred] = {}
+                    rfedit['bank'][step][pred][self.idx] = v.cpu()
+                elif rfedit['process'] == 'reverse':
+                    if step in rfedit['bank'] and pred in rfedit['bank'][step] and self.idx in rfedit['bank'][step][pred]:
+                        v = rfedit['bank'][step][pred][self.idx].to(v.device)
+            except Exception as e:
+                print(f"Warning: Error in rfedit processing for SingleStreamBlock {self.idx}: {str(e)}")

         rave_options = transformer_options.get('RAVE', None)
         if ref_config is not None and ref_config['single_strength'] > 0 and self.idx < 10:
@@ -153,17 +169,13 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, ref_config, timestep, tran
         else:
             attn = attention(q, k, v, pe=pe, mask=mask)
         _, img_attn = attn[:, :256], attn[:, 256:]
-
-        # if self.idx % 8 == 0:
-        # img_attn = flow_attention(img_attn, self.num_heads, self.hidden_dim//self.num_heads, transformer_options)
         attn[:, 256:] = img_attn

         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         return x + mod.gate * output

-

 def inject_blocks(diffusion_model):
     for i, block in enumerate(diffusion_model.double_blocks):
         block.__class__ = DoubleStreamBlock
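As a side note on the write path: the explicit `if ... not in` checks in both blocks could be collapsed with dict.setdefault, which creates each missing level in one expression. A tiny equivalent sketch under the same assumptions (the `bank` dict and key names here are illustrative):

    import torch

    bank, step, pred, layer_idx = {}, 0, 'cond', 3
    v = torch.ones(1)
    # One expression per write: setdefault builds missing levels in place.
    bank.setdefault(step, {}).setdefault(pred, {})[layer_idx] = v.cpu()
    assert layer_idx in bank[step][pred]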