Has anyone run this using an M1?
I tried changing the device to MPS... but no dice.
torch.backends.mps.is_available() is true, but autocast is defeating me.
Any ideas?
This script works on Apple M1
import requests
from io import BytesIO

import torch
from PIL import Image
from torch import autocast

from image_to_image import *  # expected to provide StableDiffusionImg2ImgPipeline and preprocess

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'running on {device}')

pipei2i = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    # revision="fp16",
    # torch_dtype=torch.float16,
    use_auth_token=True
).to(device)

# Fetch and preprocess the init image
response = requests.get('https://pbs.twimg.com/media/Fa1_7_vWYAEwfX-.png')
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((512, 512))
init_image = preprocess(init_image)

prompt = "a cat, artstation"
samples = 2
steps = 50
strength = 0.75
scale = 7.5

outputs = []
if device == 'cuda':
    # autocast is only used on CUDA; it is what breaks on MPS/CPU
    with autocast("cuda"):
        outputs = pipei2i(prompt=[prompt] * samples,
                          init_image=init_image,
                          strength=strength,
                          num_inference_steps=steps,
                          guidance_scale=scale)
else:
    outputs = pipei2i(prompt=[prompt] * samples,
                      init_image=init_image,
                      strength=strength,
                      num_inference_steps=steps,
                      guidance_scale=scale)

# Output looks like:
# {'sample': [<PIL.Image.Image image mode=RGB size=512x512 at 0x7FEE48615510>], 'nsfw_content_detected': [False]}
safe_images = []
unsafe_images = []
for i, image in enumerate(outputs["sample"]):
    if outputs["nsfw_content_detected"][i]:
        unsafe_images.append(image)
    else:
        safe_images.append(image)

for index, image in enumerate(safe_images):
    image.save(f"safe_{index}.png")
for index, image in enumerate(unsafe_images):
    image.save(f"unsafe_{index}.png")
The requirements are:
scipy
torch
transformers
diffusers
ftfy
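For reference, a minimal install command (a sketch; versions are not pinned, and you may want a virtual environment):

pip install scipy torch transformers diffusers ftfy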
This is what I can see from Activity Monitor during the operation
See the discussion here for more details, dependencies, and M1/M2 setup instructions:
https://github.com/CompVis/stable-diffusion/issues/25#issuecomment-1229430133
@patrickvonplaten I now get
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
Setting the fallback with export PYTORCH_ENABLE_MPS_FALLBACK=1 makes it work, as I can clearly see from the logs:
UserWarning: The operator 'aten::index.Tensor' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:11.)
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)]
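For reference, a sketch of setting the fallback from within Python rather than the shell (my assumption is that the variable has to be set before torch is imported; otherwise stick with the export above):

import os

# Assumption: must be set before `import torch` for the CPU fallback to take effect
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
print(torch.backends.mps.is_available())  # should print True on Apple Silicon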
In the code, torch recognizes mps, and I suppose the pipeline is loaded with fp32:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'mps' if torch.backends.mps.is_available() else device
print(f'running on {device}')
...
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-3",
    scheduler=lms,
    use_auth_token=True
).to(device)
I'm using the latest diffusers, while in the thread you shared the possible solutions rely on forks / hacks / etc.
The whole traceback is:
Traceback (most recent call last):
  File "diffuser.py", line 65, in <module>
    guidance_scale=scale,
  File "/Projects/bloom/.venv/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Projects/bloom/pipeline_stable_diffusion.py", line 142, in __call__
    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
  File "/Users/musixmatch/Documents/Projects/bloom/.venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Projects/bloom/.venv/lib/python3.7/site-packages/diffusers/models/unet_2d_condition.py", line 134, in forward
    timesteps = timesteps[None].to(sample.device)
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
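As a possible workaround on my side (an untested sketch that patches my local copy of pipeline_stable_diffusion.py, not diffusers itself), the timestep could be cast to float32 before it reaches the unet, since MPS has no float64:

# Inside the denoising loop of my local pipeline_stable_diffusion.py, just before the unet call
if torch.is_tensor(t):
    t = t.to(torch.float32)  # avoid a float64 tensor reaching the MPS device
else:
    t = torch.tensor(t, dtype=torch.float32, device=latent_model_input.device)

noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]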
Thanks!
Hi @loretoparisi!
We are working on official support for mps in Diffusers, hold tight for a couple of days :)