File size: 1,031 Bytes
811d1c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#!/usr/bin/env python3
import argparse

import torch
from diffusers import UNet2DConditionModel

def _parse_args():
    """Parse command-line options; the defaults reproduce the original fixed setup."""
    parser = argparse.ArgumentParser(
        description="Measure peak CUDA memory of one SDXL UNet fp16 forward/backward pass."
    )
    parser.add_argument(
        "--device",
        default="cuda:1",
        help="CUDA device to run on (default: cuda:1).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Number of (repeated) examples per batch (default: 8).",
    )
    parser.add_argument(
        "--no-gradient-checkpointing",
        dest="gradient_checkpointing",
        action="store_false",
        help="Disable gradient checkpointing to measure the baseline.",
    )
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")
    return args


def main():
    """Run one fp16 training step of the SDXL UNet and print peak allocated CUDA bytes."""
    args = _parse_args()

    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="unet",
        variant="fp16",
        torch_dtype=torch.float16,
    )
    unet.train()
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
    unet = unet.to(args.device)

    batch_size = args.batch_size
    device = unet.device

    # A single random example is repeated across the batch. Only the printed
    # peak memory matters here, and it depends on tensor shapes, not values.
    sample = torch.randn((1, 4, 128, 128)).half().to(device).repeat(batch_size, 1, 1, 1)
    time_ids = (torch.arange(6) / 6)[None, :].half().to(device).repeat(batch_size, 1)
    encoder_hidden_states = torch.randn((1, 77, 2048)).half().to(device).repeat(batch_size, 1, 1)
    text_embeds = torch.randn((1, 1280)).half().to(device).repeat(batch_size, 1)

    out = unet(
        sample,
        1.0,
        added_cond_kwargs={"time_ids": time_ids, "text_embeds": text_embeds},
        encoder_hidden_states=encoder_hidden_states,
    ).sample

    # Dummy MSE loss; it exists only so backward() runs through the whole UNet.
    loss = ((out - sample) ** 2).mean()
    loss.backward()

    # Peak stats are never reset in this script, so the figure covers the fp16
    # weights, the inputs, the activations and the gradients together.
    print(torch.cuda.max_memory_allocated(device=device))


if __name__ == "__main__":
    main()


# Recorded peak allocated memory in bytes with the defaults (batch size 8):
# no gradient checkpointing (--no-gradient-checkpointing): 12,276,695,552
# current gradient checkpointing (default):                10,862,276,096