Linoy Tsaban commited on
Commit
3fcb5ce
·
1 Parent(s): cb271cd

Update modified_pipeline_semantic_stable_diffusion.py

Browse files
modified_pipeline_semantic_stable_diffusion.py CHANGED
@@ -717,37 +717,37 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
717
  callback(i, t, latents)
718
 
719
 
720
- # 8. Post-processing
721
- image = self.decode_latents(latents)
722
 
723
- # 9. Run safety checker
724
- image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
725
 
726
- # 10. Convert to PIL
727
- if output_type == "pil":
728
- image = self.numpy_to_pil(image)
729
 
730
- if not return_dict:
731
- return (image, has_nsfw_concept)
732
 
733
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
734
 
735
- # # 8. Post-processing
736
- # if not output_type == "latent":
737
- # image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
738
- # image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
739
- # else:
740
- # image = latents
741
- # has_nsfw_concept = None
742
 
743
- # if has_nsfw_concept is None:
744
- # do_denormalize = [True] * image.shape[0]
745
- # else:
746
- # do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
747
 
748
- # image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
749
 
750
- # if not return_dict:
751
- # return (image, has_nsfw_concept)
752
 
753
- # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
717
  callback(i, t, latents)
718
 
719
 
720
+ # # 8. Post-processing
721
+ # image = self.decode_latents(latents)
722
 
723
+ # # 9. Run safety checker
724
+ # image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
725
 
726
+ # # 10. Convert to PIL
727
+ # if output_type == "pil":
728
+ # image = self.numpy_to_pil(image)
729
 
730
+ # if not return_dict:
731
+ # return (image, has_nsfw_concept)
732
 
733
+ # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
734
 
735
+ # 8. Post-processing
736
+ if not output_type == "latent":
737
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
738
+ image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
739
+ else:
740
+ image = latents
741
+ has_nsfw_concept = None
742
 
743
+ if has_nsfw_concept is None:
744
+ do_denormalize = [True] * image.shape[0]
745
+ else:
746
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
747
 
748
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
749
 
750
+ if not return_dict:
751
+ return (image, has_nsfw_concept)
752
 
753
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)