# DiffusionGenerator / diffusion.py
import torch
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np, os
from torch import nn
import math
import torch.nn.functional as F
from torch.optim import Adam
from typing import Optional
import random
def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)
def get_beta_schedule(beta_schedule, beta_start, beta_end, num_diffusion_timesteps):
def sigmoid(x):
return 1 / (np.exp(-x) + 1)
if beta_schedule == "quad":
betas = (
np.linspace(
beta_start ** 0.5,
beta_end ** 0.5,
num_diffusion_timesteps,
dtype=np.float64,
)
** 2
)
elif beta_schedule == "linear":
betas = np.linspace(
beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
)
elif beta_schedule == "const":
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
betas = 1.0 / np.linspace(
num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
)
elif beta_schedule == "sigmoid":
betas = np.linspace(-6, 6, num_diffusion_timesteps)
betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
else:
raise NotImplementedError(beta_schedule)
assert betas.shape == (num_diffusion_timesteps,)
betas = torch.from_numpy(betas).float()
return betas
def get_index_from_list(vals, t, x_shape):
"""
Returns a specific index t of a passed list of values vals
while considering the batch dimension.
"""
batch_size = t.shape[0]
out = vals.gather(-1, t.cpu())
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
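# Example (illustrative): with vals of shape [T] and t of shape [batch_size],
# get_index_from_list returns a tensor of shape [batch_size, 1, 1, 1] when x_shape
# is an image batch [batch_size, C, H, W], so it broadcasts over the whole image.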
def forward_diffusion_sample(x, t, device="cpu"):
"""
Takes an image and a timestep as input and
returns the noisy version of it
"""
noise = torch.randn_like(x) # gaussian noise
# noise = torch.FloatTensor(x.shape).uniform_(-1, 1) #uniform distribution noise
sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x.shape)
sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
sqrt_one_minus_alphas_cumprod, t, x.shape
)
# print("coeff stats ",sqrt_alphas_cumprod_t, " and ", sqrt_one_minus_alphas_cumprod_t)
# mean + variance
return sqrt_alphas_cumprod_t.to(device) * x.to(device) \
+ sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)
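# The expression above is the closed-form DDPM forward process
#   q(x_t | x_0) = N( sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I ),
# i.e. x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
# using the precomputed sqrt_alphas_cumprod / sqrt_one_minus_alphas_cumprod tables
# defined at the bottom of this file.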
class Block(nn.Module):
def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
super().__init__()
self.time_mlp = nn.Linear(time_emb_dim, out_ch)
if up:
self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
else:
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
self.bnorm1 = nn.BatchNorm2d(out_ch)
self.bnorm2 = nn.BatchNorm2d(out_ch)
self.relu = nn.LeakyReLU(0.2)
    def forward(self, x, t):
# First Conv
h = self.bnorm1(self.relu(self.conv1(x)))
# Time embedding
time_emb = self.relu(self.time_mlp(t))
# Extend last 2 dimensions
time_emb = time_emb[(...,) + (None,) * 2]
# Add time channel
h = h + time_emb
# Second Conv
h = self.bnorm2(self.relu(self.conv2(h)))
# Down or Upsample
return self.transform(h)
class SinusoidalPositionEmbeddings(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, time):
device = time.device
half_dim = self.dim // 2
embeddings = math.log(10000) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings
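# Shape sketch (illustrative): for a batch of timesteps `time` of shape [batch_size],
# the module returns embeddings of shape [batch_size, dim] (sin and cos halves concatenated), e.g.
#   emb = SinusoidalPositionEmbeddings(32)(torch.tensor([5, 100]))  # -> [2, 32]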
class CrossAttention(nn.Module):
"""
### Cross Attention Layer
    This falls back to self-attention when conditional embeddings are not specified.
    """
    use_flash_attention: bool = True  # not used in this implementation; normal_attention is always called
def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = False):
"""
:param d_model: is the input embedding size
:param n_heads: is the number of attention heads
        :param d_head: is the size of an attention head
:param d_cond: is the size of the conditional embeddings
:param is_inplace: specifies whether to perform the attention softmax computation inplace to
save memory
"""
super().__init__()
self.is_inplace = is_inplace
self.n_heads = n_heads
self.d_head = d_head
# Attention scaling factor
self.scale = d_head ** -0.5
# Query, key and value mappings
d_attn = d_head * n_heads
self.to_q = nn.Linear(d_model, d_attn, bias=False)
self.to_k = nn.Linear(d_cond, d_attn, bias=False)
self.to_v = nn.Linear(d_cond, d_attn, bias=False)
# Final linear layer
self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))
def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
"""
:param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
"""
# If `cond` is `None` we perform self attention
has_cond = cond is not None
if not has_cond:
cond = x
# Get query, key and value vectors
q = self.to_q(x)
k = self.to_k(cond)
v = self.to_v(cond)
return self.normal_attention(q, k, v)
def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""
#### Normal Attention
:param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
"""
# Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
q = q.view(*q.shape[:2], self.n_heads, -1)
k = k.view(*k.shape[:2], self.n_heads, -1)
v = v.view(*v.shape[:2], self.n_heads, -1)
# Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
# Compute softmax
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
if self.is_inplace:
half = attn.shape[0] // 2
attn[half:] = attn[half:].softmax(dim=-1)
attn[:half] = attn[:half].softmax(dim=-1)
else:
attn = attn.softmax(dim=-1)
# Compute attention output
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
out = torch.einsum('bhij,bjhd->bihd', attn, v)
# Reshape to `[batch_size, height * width, n_heads * d_head]`
out = out.reshape(*out.shape[:2], -1)
# Map to `[batch_size, height * width, d_model]` with a linear layer
return self.to_out(out)
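# Usage sketch (illustrative, matching the configuration used in SimpleUnet below):
#   attn = CrossAttention(d_model=3, d_cond=32, n_heads=16, d_head=16)
#   x = torch.randn(2, 64 * 64, 3)   # [batch_size, height * width, d_model]
#   cond = torch.randn(2, 1, 32)     # [batch_size, n_cond, d_cond]
#   out = attn(x, cond)              # [batch_size, height * width, d_model]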
class SimpleUnet(nn.Module):
def __init__(self):
super().__init__()
image_channels = 3
# down_channels = (64, 128, 256, 512, 1024)
# up_channels = (1024, 512, 256, 128, 64)
down_channels = (16, 32, 64, 128, 256)
up_channels = (256, 128, 64, 32, 16)
out_dim = 1
time_emb_dim = 32
# Time embedding
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(time_emb_dim),
nn.Linear(time_emb_dim, time_emb_dim),
nn.ReLU()
)
# Initial projection
self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
# Downsample
self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i + 1], \
time_emb_dim) \
for i in range(len(down_channels) - 1)])
# Upsample
self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i + 1], \
time_emb_dim, up=True) \
for i in range(len(up_channels) - 1)])
self.silu = nn.SiLU()
        self.output = nn.Conv2d(up_channels[-1], 3, out_dim)  # 1x1 conv back to 3 image channels (out_dim is the kernel size)
self.apply_tanh = nn.Tanh()
self.cross_attention_module = CrossAttention(3, 32, 16, 16)
def forward(self, x, y, timestep):
        # Embed the class condition using cross attention
batch_size = x.shape[0]
y = self.time_mlp(y)
y = y[:, None, :]
x = x.permute(0, 2, 3, 1).view(batch_size, IMG_SIZE * IMG_SIZE, 3)
x2 = x + self.cross_attention_module(x, y)
x2 = x2.view(batch_size, IMG_SIZE, IMG_SIZE, 3).permute(0, 3, 1, 2)
        # Embed the timestep
t = self.time_mlp(timestep)
# Initial conv
x2 = self.conv0(x2)
# Unet
residual_inputs = []
for down in self.downs:
x2 = down(x2, t)
residual_inputs.append(x2)
for up in self.ups:
residual_x2 = residual_inputs.pop()
# Add residual x2 as additional channels
x2 = torch.cat((x2, residual_x2), dim=1)
x2 = up(x2, t)
x2 = self.silu(x2)
output = self.output(x2)
return output
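# Shape sketch (illustrative): with IMG_SIZE = 64 and T = 1000 as defined below,
#   model = SimpleUnet()
#   x = torch.randn(4, 3, IMG_SIZE, IMG_SIZE)   # noisy images
#   y = torch.tensor([0, 1, 0, 1])              # per-image class condition
#   t = torch.randint(0, T, (4,)).long()        # per-image timesteps
#   out = model(x, y, t)                        # [4, 3, IMG_SIZE, IMG_SIZE]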
def get_loss(model, x_0, t):
    latent, condition = x_0  # latent and condition have the same shape
    device = next(model.parameters()).device
    latent = latent.to(device)
    condition = condition.to(device)
    x_noisy, noise = forward_diffusion_sample(latent, t, device)
    noise_pred = model(x_noisy, condition, t)
    # return F.l1_loss(noise, noise_pred)
    return F.mse_loss(noise, noise_pred)
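# Minimal training sketch (illustrative; `dataloader` is a hypothetical DataLoader
# yielding (latent, condition) pairs and is not defined in this file):
#   model = SimpleUnet().cuda()
#   optimizer = Adam(model.parameters(), lr=1e-4)
#   for latent, condition in dataloader:
#       t = torch.randint(0, T, (latent.shape[0],), device="cuda").long()
#       loss = get_loss(model, (latent, condition), t)
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()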
@torch.no_grad()
def sample_timestep(x, model, y, t):
    """
    Performs one reverse-diffusion step: predicts the noise with the model,
    computes the posterior mean, and adds posterior noise for t > 0.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
sqrt_one_minus_alphas_cumprod, t, x.shape
)
sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
# Call model (current image - noise prediction)
model_mean = sqrt_recip_alphas_t * (
x - (betas_t / sqrt_one_minus_alphas_cumprod_t) * model(x, y, t)
)
posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
# print("model prediction stats ",torch.max(model(x, y, t)), " and ", torch.min(model(x, y, t)))
if t == 0:
return model_mean
else:
noise = torch.randn_like(x)
return model_mean + torch.sqrt(posterior_variance_t) * noise
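# The step above is the standard DDPM reverse update:
#   mu_theta(x_t, t) = 1/sqrt(alpha_t) * ( x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, y, t) )
#   x_{t-1} = mu_theta(x_t, t) + sqrt(posterior_variance_t) * z,  z ~ N(0, I)  (no noise added at t = 0)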
def show_tensor_image(image):
reverse_transforms = transforms.Compose([
transforms.Lambda(lambda t: (t + 1) / 2),
transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
transforms.Lambda(lambda t: t * 255.),
transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
transforms.ToPILImage(),
])
# Take first image of batch
if len(image.shape) == 4:
image = image[0, :, :, :]
plt.imshow(reverse_transforms(image))
def generate_latent(model_dir, cancer_type, output_dir):
    if cancer_type == 'benign':
model_name = "digestpath_mask_benign_default.pt"
else:
model_name = "digestpath_mask_malignant_default.pt"
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = os.path.join(model_dir, model_name)
model = SimpleUnet()
model.to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
print("model loaded")
model.eval()
# cancer_grade = random.randint(0, 1)
    condition = torch.tensor([0 if cancer_type == 'benign' else 1], device=device)  # benign: 0 / malignant: 1 grade cancer
# condition = torch.full([1, 1, IMG_SIZE, IMG_SIZE], condition).float().cuda()
img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)
    for j in reversed(range(T)):
t = torch.full((1,), j, device=device, dtype=torch.long)
img = sample_timestep(img, model, condition, t)
print("sampled image ", torch.max(img), " and ", torch.min(img))
save_image(img, os.path.join(output_dir, "sample.png"))
torch.save(img, os.path.join(output_dir, "sample.pt"))
# Define beta schedule
T = 1000
IMG_SIZE = 64
betas = get_beta_schedule(beta_schedule="linear",
beta_start=0.0001,
beta_end=0.02,
num_diffusion_timesteps=T)
# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) # root(alpha_bar)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod) # root(1-alpha_bar)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas) # 1/root(alpha)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
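# posterior_variance corresponds to beta_tilde_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t),
# the variance of q(x_{t-1} | x_t, x_0) used in sample_timestep above.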
# model_dir = "trained_models/diffusion"
# output_dir = r"F:\Datasets\DigestPath\scene_generation\all\1000\256\test\output\benign"
# generate_latent(model_dir, 'malignant', output_dir)