import torch
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np, os
from torch import nn
import math
import torch.nn.functional as F
from torch.optim import Adam
from typing import Optional
import random

def mkdir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)

def get_beta_schedule(beta_schedule, beta_start, beta_end, num_diffusion_timesteps):
    def sigmoid(x):
        return 1 / (np.exp(-x) + 1)

    if beta_schedule == "quad":
        betas = (
            np.linspace(
                beta_start ** 0.5,
                beta_end ** 0.5,
                num_diffusion_timesteps,
                dtype=np.float64,
            )
            ** 2
        )
    elif beta_schedule == "linear":
        betas = np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "const":
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1.0 / np.linspace(
            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "sigmoid":
        betas = np.linspace(-6, 6, num_diffusion_timesteps)
        betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    betas = torch.from_numpy(betas).float()
    return betas
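
# A quick sketch of the other supported schedules, using the same hyperparameters
# as the linear schedule built at the bottom of this file (for illustration only):
# betas_quad = get_beta_schedule("quad", beta_start=1e-4, beta_end=2e-2, num_diffusion_timesteps=1000)
# betas_sig = get_beta_schedule("sigmoid", beta_start=1e-4, beta_end=2e-2, num_diffusion_timesteps=1000)
# Each returns a float32 tensor of shape (1000,).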

def get_index_from_list(vals, t, x_shape):
    """
    Gathers the entries of vals at the timesteps in t and reshapes them so
    they broadcast over a batch of images of shape x_shape.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
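
# Example (sketch): with vals of shape (T,) and t = torch.tensor([3, 7]), the result
# has shape (2, 1, 1, 1) for x_shape == (2, 3, 64, 64), so it multiplies elementwise
# against a batch of images.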

def forward_diffusion_sample(x, t, device="cpu"):
    """
    Takes an image x_0 and a timestep t and returns the noisy image x_t
    together with the noise used to produce it.
    """
    noise = torch.randn_like(x)  # Gaussian noise
    # noise = torch.FloatTensor(x.shape).uniform_(-1, 1)  # uniform-distribution noise
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    # print("coeff stats ", sqrt_alphas_cumprod_t, " and ", sqrt_one_minus_alphas_cumprod_t)
    # mean + variance: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    return sqrt_alphas_cumprod_t.to(device) * x.to(device) \
        + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)
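
# Minimal usage sketch (assumes the schedule tensors defined at the bottom of this file):
# x0 = torch.rand(4, 3, IMG_SIZE, IMG_SIZE) * 2 - 1         # a batch of images scaled to [-1, 1]
# t = torch.randint(0, T, (4,), dtype=torch.long)           # one timestep per sample
# x_t, eps = forward_diffusion_sample(x0, t, device="cpu")  # noisy images and the noise that was added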

class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            # Upsampling block: the input is the previous feature map concatenated
            # with its skip connection, hence 2 * in_ch input channels.
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x, t):
        # First conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend the last 2 dimensions so the embedding broadcasts over H and W
        time_emb = time_emb[(...,) + (None,) * 2]
        # Add time channel
        h = h + time_emb
        # Second conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down- or upsample
        return self.transform(h)
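
# Shape sketch, matching the channel tuples used in SimpleUnet below:
# down = Block(16, 32, time_emb_dim=32)          # (B, 16, H, W) -> (B, 32, H/2, W/2)
# up = Block(32, 16, time_emb_dim=32, up=True)   # expects (B, 64, H, W) -> (B, 16, 2H, 2W)
# Both take the (B, time_emb_dim) time embedding as their second argument.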

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
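
# For a batch of timesteps of shape (B,), the output has shape (B, dim), e.g.:
# emb = SinusoidalPositionEmbeddings(32)(torch.tensor([0, 10, 999]))  # -> shape (3, 32)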

class CrossAttention(nn.Module):
    """
    ### Cross Attention Layer

    This falls back to self-attention when conditional embeddings are not specified.
    """

    use_flash_attention: bool = True  # not used here; `normal_attention` is always called

    def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = False):
        """
        :param d_model: is the input embedding size
        :param n_heads: is the number of attention heads
        :param d_head: is the size of an attention head
        :param d_cond: is the size of the conditional embeddings
        :param is_inplace: specifies whether to perform the attention softmax computation inplace to
            save memory
        """
        super().__init__()
        self.is_inplace = is_inplace
        self.n_heads = n_heads
        self.d_head = d_head
        # Attention scaling factor
        self.scale = d_head ** -0.5
        # Query, key and value mappings
        d_attn = d_head * n_heads
        self.to_q = nn.Linear(d_model, d_attn, bias=False)
        self.to_k = nn.Linear(d_cond, d_attn, bias=False)
        self.to_v = nn.Linear(d_cond, d_attn, bias=False)
        # Final linear layer
        self.to_out = nn.Sequential(nn.Linear(d_attn, d_model))

    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None):
        """
        :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
        :param cond: are the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
        """
        # If `cond` is `None` we perform self-attention
        has_cond = cond is not None
        if not has_cond:
            cond = x
        # Get query, key and value vectors
        q = self.to_q(x)
        k = self.to_k(cond)
        v = self.to_v(cond)
        return self.normal_attention(q, k, v)

    def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        #### Normal Attention

        :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        """
        # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
        q = q.view(*q.shape[:2], self.n_heads, -1)
        k = k.view(*k.shape[:2], self.n_heads, -1)
        v = v.view(*v.shape[:2], self.n_heads, -1)
        # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
        attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
        # Compute softmax
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
        if self.is_inplace:
            half = attn.shape[0] // 2
            attn[half:] = attn[half:].softmax(dim=-1)
            attn[:half] = attn[:half].softmax(dim=-1)
        else:
            attn = attn.softmax(dim=-1)
        # Compute attention output
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
        out = torch.einsum('bhij,bjhd->bihd', attn, v)
        # Reshape to `[batch_size, height * width, n_heads * d_head]`
        out = out.reshape(*out.shape[:2], -1)
        # Map to `[batch_size, height * width, d_model]` with a linear layer
        return self.to_out(out)
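
# Isolated usage sketch (the same configuration SimpleUnet instantiates below):
# attn = CrossAttention(d_model=3, d_cond=32, n_heads=16, d_head=16)
# x = torch.randn(2, 64 * 64, 3)    # flattened image tokens
# cond = torch.randn(2, 1, 32)      # one conditional embedding per image
# out = attn(x, cond)               # -> shape (2, 64 * 64, 3)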

class SimpleUnet(nn.Module):
    def __init__(self):
        super().__init__()
        image_channels = 3
        # down_channels = (64, 128, 256, 512, 1024)
        # up_channels = (1024, 512, 256, 128, 64)
        down_channels = (16, 32, 64, 128, 256)
        up_channels = (256, 128, 64, 32, 16)
        out_dim = 1
        time_emb_dim = 32
        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )
        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)
        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i + 1],
                                          time_emb_dim)
                                    for i in range(len(down_channels) - 1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i + 1],
                                        time_emb_dim, up=True)
                                  for i in range(len(up_channels) - 1)])
        self.silu = nn.SiLU()
        # 1x1 convolution back to 3 output channels
        self.output = nn.Conv2d(up_channels[-1], 3, out_dim)
        self.apply_tanh = nn.Tanh()  # defined but not applied in forward
        self.cross_attention_module = CrossAttention(3, 32, 16, 16)

    def forward(self, x, y, timestep):
        # Embed the class condition and fuse it into the image with cross-attention
        batch_size = x.shape[0]
        y = self.time_mlp(y)
        y = y[:, None, :]
        x = x.permute(0, 2, 3, 1).view(batch_size, IMG_SIZE * IMG_SIZE, 3)
        x2 = x + self.cross_attention_module(x, y)
        x2 = x2.view(batch_size, IMG_SIZE, IMG_SIZE, 3).permute(0, 3, 1, 2)
        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x2 = self.conv0(x2)
        # U-Net
        residual_inputs = []
        for down in self.downs:
            x2 = down(x2, t)
            residual_inputs.append(x2)
        for up in self.ups:
            residual_x2 = residual_inputs.pop()
            # Add residual x2 as additional channels (skip connection)
            x2 = torch.cat((x2, residual_x2), dim=1)
            x2 = up(x2, t)
        x2 = self.silu(x2)
        output = self.output(x2)
        return output
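
# Shape sketch (assumes IMG_SIZE = 64, as set at the bottom of this file):
# model = SimpleUnet()
# x = torch.randn(2, 3, IMG_SIZE, IMG_SIZE)
# cond = torch.tensor([0, 1])                        # benign / malignant labels
# t = torch.randint(0, T, (2,), dtype=torch.long)
# noise_pred = model(x, cond, t)                     # -> shape (2, 3, IMG_SIZE, IMG_SIZE)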

def get_loss(model, x_0, t):
    latent, condition = x_0  # unpack the latent image and its class condition
    latent = latent.cuda()
    condition = condition.cuda()
    x_noisy, noise = forward_diffusion_sample(latent, t, latent.device)
    noise_pred = model(x_noisy, condition, t)
    # return F.l1_loss(noise, noise_pred)
    return F.mse_loss(noise, noise_pred)
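
# Minimal training-loop sketch (data_loader is a hypothetical DataLoader yielding
# (latent, condition) batches; it is not defined in this file):
# model = SimpleUnet().cuda()
# optimizer = Adam(model.parameters(), lr=1e-4)
# for latent, condition in data_loader:
#     t = torch.randint(0, T, (latent.shape[0],), device="cuda", dtype=torch.long)
#     loss = get_loss(model, (latent, condition), t)
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()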

def sample_timestep(x, model, y, t):
    """
    One reverse-diffusion step: denoise x at timestep t, conditioned on y.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
    # Posterior mean: current image minus the scaled noise prediction
    model_mean = sqrt_recip_alphas_t * (
        x - (betas_t / sqrt_one_minus_alphas_cumprod_t) * model(x, y, t)
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    # print("model prediction stats ", torch.max(model(x, y, t)), " and ", torch.min(model(x, y, t)))
    if t == 0:
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise
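
# This implements the DDPM update
#   x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, y, t)) + sqrt(beta_tilde_t) * z
# with z ~ N(0, I) for t > 0 and z = 0 at t = 0.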

def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])
    # Take the first image of the batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))

def generate_latent(model_dir, cancer_type, output_dir):
    if cancer_type == 'benign':
        model_name = "digestpath_mask_benign_default.pt"
    else:
        model_name = "digestpath_mask_malignant_default.pt"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_path = os.path.join(model_dir, model_name)
    model = SimpleUnet()
    model.to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    print("model loaded")
    model.eval()
    # cancer_grade = random.randint(0, 1)
    # benign: 0 / malignant: 1, matching the model selected above
    condition = torch.tensor([0 if cancer_type == 'benign' else 1], device=device)
    # condition = torch.full([1, 1, IMG_SIZE, IMG_SIZE], condition).float().cuda()
    img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)
    with torch.no_grad():
        for j in range(0, T)[::-1]:
            t = torch.full((1,), j, device=device, dtype=torch.long)
            img = sample_timestep(img, model, condition, t)
    print("sampled image ", torch.max(img), " and ", torch.min(img))
    mkdir(output_dir)
    save_image(img, os.path.join(output_dir, "sample.png"))
    torch.save(img, os.path.join(output_dir, "sample.pt"))

# Define beta schedule
T = 1000
IMG_SIZE = 64
betas = get_beta_schedule(beta_schedule="linear",
                          beta_start=0.0001,
                          beta_end=0.02,
                          num_diffusion_timesteps=T)
# Pre-calculate the terms of the closed-form forward process
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)                 # sqrt(alpha_bar)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)  # sqrt(1 - alpha_bar)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)                     # 1 / sqrt(alpha)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
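
# These follow the standard DDPM definitions:
#   alpha_bar_t = prod_{s<=t} alpha_s
#   beta_tilde_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)   (posterior variance)
# posterior_variance[0] is 0 and is never used, since sample_timestep returns the mean at t == 0.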

# model_dir = "trained_models/diffusion"
# output_dir = r"F:\Datasets\DigestPath\scene_generation\all\1000\256\test\output\benign"
# generate_latent(model_dir, 'malignant', output_dir)