Skip to content

Commit

Permalink
Revert my safety_checker hack because I got errors with has_nsfw_concepts (can't iterate bool)
Browse files Browse the repository at this point in the history
  • Loading branch information
Skquark authored Dec 12, 2023
1 parent 58b3e00 commit 61d6948
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/diffusers/pipelines/stable_diffusion/safety_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ def forward(self, clip_input, images):
else:
images[idx] = np.zeros(images[idx].shape) # black image

#if any(has_nsfw_concepts):
# logger.warning(
# "Potential NSFW content was detected in one or more images. A black image will be returned instead."
# " Try again with a different prompt and/or seed."
# )
if any(has_nsfw_concepts):
logger.warning(
"Potential NSFW content was detected in one or more images. A black image will be returned instead."
" Try again with a different prompt and/or seed."
)

#return images, has_nsfw_concepts
return images, False
return images, has_nsfw_concepts
#return images, False

@torch.no_grad()
def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
Expand All @@ -118,10 +118,10 @@ def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor)
special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])

concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
# concept_scores = concept_scores.round(decimals=3)
#has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
concept_scores = concept_scores.round(decimals=3)
has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)

#images[has_nsfw_concepts] = 0.0 # black image

#return images, has_nsfw_concepts
return images, False
return images, has_nsfw_concepts
#return images, False

0 comments on commit 61d6948

Please sign in to comment.