import torch
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np, os
from torch import nn
import math
import torch.nn.functional as F
from torch.optim import Adam
from typing import Optional
import random


def mkdir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)


def get_beta_schedule(beta_schedule, beta_start, beta_end, num_diffusion_timesteps):
    def sigmoid(x):
        return 1 / (np.exp(-x) + 1)

    if beta_schedule == "quad":
        betas = (
            np.linspace(
                beta_start ** 0.5,
                beta_end ** 0.5,
                num_diffusion_timesteps,
                dtype=np.float64,
            )
            ** 2
        )
    elif beta_schedule == "linear":
        betas = np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "const":
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1.0 / np.linspace(
            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "sigmoid":
        betas = np.linspace(-6, 6, num_diffusion_timesteps)
        betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    betas = torch.from_numpy(betas).float()
    return betas


def get_index_from_list(vals, t, x_shape):
    """
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


def forward_diffusion_sample(x, t, device="cpu"):
    """
    Takes an image and a timestep as input and
    returns the noisy version of it.
    """
    noise = torch.randn_like(x)  # Gaussian noise
    # noise = torch.FloatTensor(x.shape).uniform_(-1, 1)  # uniform distribution noise
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    # print("coeff stats ", sqrt_alphas_cumprod_t, " and ", sqrt_one_minus_alphas_cumprod_t)
    # mean + variance
    return sqrt_alphas_cumprod_t.to(device) * x.to(device) \
        + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)


class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x, t):
        # First conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(...,) + (None,) * 2]
        # Add time channel
        h = h + time_emb
        # Second conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down- or upsample
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class CrossAttention(nn.Module):
    """
    ### Cross Attention Layer

    This falls back to self-attention when conditional embeddings are not specified.
    """

    # Not used here: forward() always calls normal_attention()
    use_flash_attention: bool = True

    def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = False):
        """
        :param d_model: is the input embedding size
        :param n_heads: is the number of attention heads
        :param d_head: is the size of an attention head
        :param d_cond: is the size of the conditional embeddings
        :param is_inplace: specifies whether to perform the attention softmax computation inplace to save memory
        """
        super().__init__()

        self.is_inplace = is_inplace
        self.n_heads = n_heads
        self.d_head = d_head

        # Attention scaling factor
        self.scale = d_head ** -0.5

        # Query, key and value mappings
        d_attn = d_head * n_heads
        self.to_q = nn.Linear(d_model, d_attn, bias=False)
        self.to_k = nn.Linear(d_cond, d_attn, bias=False)
        self.to_v = nn.Linear(d_cond, d_attn, bias=False)

        # Final linear layer
        self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))

    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
        """
        :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
        :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
        """
        # If `cond` is `None` we perform self attention
        has_cond = cond is not None
        if not has_cond:
            cond = x

        # Get query, key and value vectors
        q = self.to_q(x)
        k = self.to_k(cond)
        v = self.to_v(cond)

        return self.normal_attention(q, k, v)

    def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        #### Normal Attention

        :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        """
        # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
        q = q.view(*q.shape[:2], self.n_heads, -1)
        k = k.view(*k.shape[:2], self.n_heads, -1)
        v = v.view(*v.shape[:2], self.n_heads, -1)

        # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
        attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale

        # Compute softmax
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
        if self.is_inplace:
            half = attn.shape[0] // 2
            attn[half:] = attn[half:].softmax(dim=-1)
            attn[:half] = attn[:half].softmax(dim=-1)
        else:
            attn = attn.softmax(dim=-1)

        # Compute attention output
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
        out = torch.einsum('bhij,bjhd->bihd', attn, v)
        # Reshape to `[batch_size, height * width, n_heads * d_head]`
        out = out.reshape(*out.shape[:2], -1)
        # Map to `[batch_size, height * width, d_model]` with a linear layer
        return self.to_out(out)


class SimpleUnet(nn.Module):
    def __init__(self):
        super().__init__()
        image_channels = 3
        # down_channels = (64, 128, 256, 512, 1024)
        # up_channels = (1024, 512, 256, 128, 64)
        down_channels = (16, 32, 64, 128, 256)
        up_channels = (256, 128, 64, 32, 16)
        out_dim = 1
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i + 1],
                                          time_emb_dim)
                                    for i in range(len(down_channels) - 1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i + 1],
                                        time_emb_dim, up=True)
                                  for i in range(len(up_channels) - 1)])

        self.silu = nn.SiLU()
        # 1x1 convolution back to 3 image channels (out_dim is the kernel size here)
        self.output = nn.Conv2d(up_channels[-1], 3, out_dim)
        self.apply_tanh = nn.Tanh()  # defined but not applied in forward()
        # d_model=3 (RGB values as token features), d_cond=32, 16 heads of size 16
        self.cross_attention_module = CrossAttention(3, 32, 16, 16)

    def forward(self, x, y, timestep):
        # Embed the class condition using cross attention
        # (the condition label y is passed through the same sinusoidal time MLP)
        batch_size = x.shape[0]
        y = self.time_mlp(y)
        y = y[:, None, :]
        x = x.permute(0, 2, 3, 1).view(batch_size, IMG_SIZE * IMG_SIZE, 3)
        x2 = x + self.cross_attention_module(x, y)
        x2 = x2.view(batch_size, IMG_SIZE, IMG_SIZE, 3).permute(0, 3, 1, 2)
        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x2 = self.conv0(x2)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x2 = down(x2, t)
            residual_inputs.append(x2)
        for up in self.ups:
            residual_x2 = residual_inputs.pop()
            # Add residual x2 as additional channels
            x2 = torch.cat((x2, residual_x2), dim=1)
            x2 = up(x2, t)
        x2 = self.silu(x2)
        output = self.output(x2)
        return output


def get_loss(model, x_0, t):
    """Denoising loss: noise the latent at timestep t and regress the predicted noise."""
    latent, condition = x_0  # both latent and condition have the same shape
    latent = latent.cuda()
    condition = condition.cuda()
    x_noisy, noise = forward_diffusion_sample(latent, t, device)
    noise_pred = model(x_noisy, condition, t)
    # return F.l1_loss(noise, noise_pred)
    return F.mse_loss(noise, noise_pred)


@torch.no_grad()
def sample_timestep(x, model, y, t):
    """
    One reverse (denoising) DDPM step: use the model's noise prediction to compute
    the posterior mean, then add noise scaled by the posterior variance (except at t = 0).
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - (betas_t / sqrt_one_minus_alphas_cumprod_t) * model(x, y, t)
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    # print("model prediction stats ", torch.max(model(x, y, t)), " and ", torch.min(model(x, y, t)))

    if t == 0:
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise


def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))


def generate_latent(model_dir, cancer_type, output_dir):
    """Run the full reverse-diffusion chain from pure noise with the checkpoint
    for the given cancer type and save the resulting sample to output_dir."""
    if cancer_type == 'benign':
        model_name = "digestpath_mask_benign_default.pt"
    else:
        model_name = "digestpath_mask_malignant_default.pt"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_path = os.path.join(model_dir, model_name)
    model = SimpleUnet()
    model.to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    print("model loaded")
    model.eval()

    # cancer_grade = random.randint(0, 1)
    condition = torch.tensor([1]).cuda()  # benign: 0 / malignant: 1 grade cancer
    # condition = torch.full([1, 1, IMG_SIZE, IMG_SIZE], condition).float().cuda()
    img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)
    for j in reversed(range(T)):
        t = torch.full((1,), j, device=device, dtype=torch.long)
        img = sample_timestep(img, model, condition, t)
    print("sampled image ", torch.max(img), " and ", torch.min(img))
    save_image(img, os.path.join(output_dir, "sample.png"))
    torch.save(img, os.path.join(output_dir, "sample.pt"))


# Define beta schedule
T = 1000
IMG_SIZE = 64
betas = get_beta_schedule(beta_schedule="linear",
                          beta_start=0.0001,
                          beta_end=0.02,
                          num_diffusion_timesteps=T)

# Pre-calculate the different terms used in the closed-form expressions
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)  # sqrt(alpha_bar)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)  # sqrt(1 - alpha_bar)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)  # 1 / sqrt(alpha)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

# Default device; get_loss relies on this module-level name
device = "cuda" if torch.cuda.is_available() else "cpu"

# model_dir = "trained_models/diffusion"
# output_dir = r"F:\Datasets\DigestPath\scene_generation\all\1000\256\test\output\benign"
# generate_latent(model_dir, 'malignant', output_dir)
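

# --- Minimal training-loop sketch (an assumption; no training loop ships in this file) ---
# get_loss above defines the denoising objective, but how it is driven is not shown here.
# The function below is one plausible wiring; the names `train_loader`, `epochs`, and `lr`
# are hypothetical, and `train_loader` is assumed to yield (latent, condition) batches in
# the format get_loss expects.
def train(model, train_loader, epochs=100, lr=1e-4):
    optimizer = Adam(model.parameters(), lr=lr)
    model = model.to(device)
    model.train()
    for epoch in range(epochs):
        for latent, condition in train_loader:
            optimizer.zero_grad()
            # Draw a random timestep for every sample in the batch
            t = torch.randint(0, T, (latent.shape[0],), device=device, dtype=torch.long)
            loss = get_loss(model, (latent, condition), t)
            loss.backward()
            optimizer.step()
        print(f"epoch {epoch}: loss {loss.item():.4f}")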