import glob
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image

from gan_losses import get_gan_losses

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ## Load Data
# Originally the reconstruction loss was normalised by the training-set
# variance: data_variance = np.var(training_data.data / 255.0)
data_variance = 1


def mkdir(dir):
    """Create directory ``dir`` (including missing parents) if it is absent.

    ``exist_ok=True`` replaces the former ``os.path.exists`` check, which
    could race with another process creating the same directory.
    """
    # NOTE: ``dir`` shadows the builtin; the name is kept for keyword callers.
    os.makedirs(dir, exist_ok=True)


def read_image(img_path):
    """Read an image file and return it as an RGB float array scaled to [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path`` (``cv2.imread``
            signals failure by returning ``None`` instead of raising).
    """
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"cv2 could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
    img = img / 255.0
    return img


class VectorQuantizer(nn.Module):
    """Vector-quantisation layer whose codebook is learned through the loss.

    Every spatial position of the input is replaced by its nearest codebook
    vector (squared Euclidean distance).

    Args:
        num_embeddings: number of codebook vectors.
        embedding_dim: length of each codebook vector; must equal the channel
            count of the input passed to ``forward``.
        commitment_cost: weight of the encoder commitment term in the loss.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        # codebook
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        """Quantise a BCHW tensor (C == embedding_dim).

        Returns:
            (loss, quantized, perplexity, encoding_indices) where ``quantized``
            is BCHW, ``perplexity`` is exp(entropy) of the average codebook
            usage, and ``encoding_indices`` has shape (B*H*W, 1).
        """
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Squared distances ||x||^2 + ||e||^2 - 2 x.e to every codebook vector
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding: one-hot of the nearest codebook vector
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Loss: codebook term (moves codes) + weighted commitment term (moves encoder)
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through estimator: forward value is `quantized`, gradient goes to `inputs`
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encoding_indices


class VectorQuantizerEMA(nn.Module):
    """Vector-quantisation layer whose codebook is updated by moving averages.

    In training mode the codebook is recomputed each forward pass from
    exponential moving averages (factor ``decay``) of per-code assignment
    counts and of the summed inputs assigned to each code; the returned loss
    only contains the commitment term.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost

        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon  # Laplace-smoothing constant for cluster sizes

    def forward(self, inputs):
        """Quantise a BCHW tensor (C == embedding_dim); see ``VectorQuantizer.forward``."""
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        # Former experiment (disabled) remapping code 3 to 4. Original label
        # notes: 1 background, 2 epithelial cells, 3 neutrophil, 4 connective,
        # 5 plasma, 6 lymphocytes.
        # encoding_indices[encoding_indices == 3] = 4
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)

        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)

            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n)

            dw = torch.matmul(encodings.t(), flat_input)
            # NOTE(review): these rebind brand-new nn.Parameter objects every
            # training step, so an optimizer built earlier holds stale ones.
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)

            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encoding_indices


class Residual(nn.Module):
    """Residual block: ReLU -> 3x3 conv -> ReLU -> 1x1 conv, added to the input.

    ``in_channels`` must equal ``num_hiddens`` for the skip addition to work.
    """

    def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
        super(Residual, self).__init__()
        self._block = nn.Sequential(
            nn.ReLU(True),
            nn.Conv2d(in_channels=in_channels,
                      out_channels=num_residual_hiddens,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(in_channels=num_residual_hiddens,
                      out_channels=num_hiddens,
                      kernel_size=1, stride=1, bias=False)
        )

    def forward(self, x):
        # NOTE(review): the first ReLU is in-place, so it rewrites `x` before the
        # addition; the skip path therefore carries relu(x), not x. Left as-is
        # because changing it would alter outputs of models trained with it.
        return x + self._block(x)


class ResidualStack(nn.Module):
    """``num_residual_layers`` ``Residual`` blocks applied in sequence, then ReLU."""

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self._num_residual_layers = num_residual_layers
        self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
                                      for _ in range(self._num_residual_layers)])

    def forward(self, x):
        for i in range(self._num_residual_layers):
            x = self._layers[i](x)
        return F.relu(x)


class Encoder(nn.Module):
    """Convolutional encoder: two stride-2 convs (4x spatial downsampling for
    sizes divisible by 4), a 3x3 conv, a residual stack and a 1x1 projection to
    ``embedding_dim`` channels."""

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens, embedding_dim):
        super(Encoder, self).__init__()

        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens // 2,
                                 kernel_size=4,
                                 stride=2, padding=1)
        self._conv_2 = nn.Conv2d(in_channels=num_hiddens // 2,
                                 out_channels=num_hiddens,
                                 kernel_size=4,
                                 stride=2, padding=1)
        self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)
        self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens,
                                      out_channels=embedding_dim,
                                      kernel_size=1,
                                      stride=1)
        # Unused in forward(); kept so the module structure stays unchanged.
        self.apply_tanh = nn.Tanh()

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = F.relu(x)

        x = self._conv_2(x)
        x = F.relu(x)

        x = self._conv_3(x)
        x = self._residual_stack(x)
        x = self._pre_vq_conv(x)
        return x


class Decoder(nn.Module):
    """Convolutional decoder mirroring ``Encoder``: 3x3 conv, residual stack and
    two stride-2 transposed convs (4x upsampling) to 3 channels, then tanh
    (output in [-1, 1])."""

    def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Decoder, self).__init__()

        self._conv_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=num_hiddens,
                                 kernel_size=3,
                                 stride=1, padding=1)
        self._residual_stack = ResidualStack(in_channels=num_hiddens,
                                             num_hiddens=num_hiddens,
                                             num_residual_layers=num_residual_layers,
                                             num_residual_hiddens=num_residual_hiddens)
        self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
                                                out_channels=num_hiddens // 2,
                                                kernel_size=4,
                                                stride=2, padding=1)
        self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens // 2,
                                                out_channels=3,
                                                kernel_size=4,
                                                stride=2, padding=1)
        self.apply_tanh = nn.Tanh()

    def forward(self, inputs):
        x = self._conv_1(inputs)
        x = self._residual_stack(x)
        x = self._conv_trans_1(x)
        x = F.relu(x)
        x = self._conv_trans_2(x)
        return self.apply_tanh(x)


class VQModel(nn.Module):
    """VQ-VAE for 3-channel images: Encoder -> vector quantiser -> Decoder.

    ``decay > 0`` selects the EMA codebook (``VectorQuantizerEMA``); otherwise
    the loss-trained ``VectorQuantizer`` is used.
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        super(VQModel, self).__init__()

        self._encoder = Encoder(3, num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens, embedding_dim)
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                              commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
        self._decoder = Decoder(embedding_dim,
                                num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)

    def forward(self, x):
        """Return ``(vq_loss, reconstruction, perplexity)`` for a BCHW batch."""
        z = self._encoder(x)
        loss, quantized, perplexity, _ = self._vq_vae(z)
        x_recon = self._decoder(quantized)

        return loss, x_recon, perplexity


def save_generated_images(image_names, batch_images, ind, mode, type, output_dir=None):
    """Save each image of ``batch_images`` to ``<output_dir>/<mode>/<type>/<name>``.

    Args:
        image_names: file names, indexed in step with ``batch_images``.
        batch_images: tensor whose first dimension indexes images.
        ind: unused; kept for call compatibility.
        mode, type: sub-directory names (``type`` shadows the builtin; the name
            is kept for keyword callers).
        output_dir: root directory. When omitted, a module-level ``output_dir``
            is used if one has been set; the module itself never defines it,
            which previously made this function always fail with NameError.

    Raises:
        ValueError: if no output directory is available.
    """
    if output_dir is None:
        output_dir = globals().get("output_dir")
    if output_dir is None:
        raise ValueError("save_generated_images: pass output_dir (no module-level "
                         "'output_dir' is defined)")
    current_output_dir = os.path.join(output_dir, mode, type)
    mkdir(current_output_dir)
    for i, image in enumerate(batch_images):
        save_image(image, os.path.join(current_output_dir, image_names[i]))


def generate_images_from_diffusion_latents(model, latents_path, output_dir):
    """Decode every ``*.pt`` latent in ``latents_path`` into a PNG in ``output_dir``.

    Each latent is passed through ``model._vq_vae`` (quantisation) and
    ``model._decoder``; the PNG keeps the latent's file stem.
    """
    # Load latents onto the model's own device instead of hard-coding CUDA,
    # so this also works when the model runs on CPU.
    model_device = next(model.parameters()).device
    mkdir(output_dir)
    latent_paths = sorted(glob.glob(os.path.join(latents_path, "*.pt")))
    with torch.no_grad():  # inference only: skip building autograd graphs
        for latent_path in latent_paths:
            latent = torch.load(latent_path, map_location=model_device).detach()
            _, quantized_latent, _, _ = model._vq_vae(latent)
            image = model._decoder(quantized_latent)
            # splitext keeps inner dots ("a.1.pt" -> "a.1.png"), avoiding
            # collisions that split(".")[0] caused.
            image_name = os.path.splitext(os.path.basename(latent_path))[0] + ".png"
            save_image(image, os.path.join(output_dir, image_name))


class UNetDown(nn.Module):
    """U-Net encoder stage: 4x4 stride-2 conv (halves H and W), optional
    InstanceNorm, LeakyReLU(0.2), optional dropout."""

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    """U-Net decoder stage: 4x4 stride-2 transposed conv (doubles H and W),
    InstanceNorm, ReLU, optional dropout, then channel-concat with the skip."""

    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()
        layers = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))

        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)

        return x


class Pix2PixGenerator(nn.Module):
    """Pix2Pix-style U-Net generator with 8 down / 7 up stages and tanh output.

    Spatial sizes divisible by 2**8 = 256 keep every skip connection aligned.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super(Pix2PixGenerator, self).__init__()

        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, dropout=0.5)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)

        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # U-Net generator with skip connections from encoder to decoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)

        return self.final(u7)


# Hyperparameters; these must match the checkpoint named by `model_name`.
batch_size = 32  # Keep 16 for good results
num_training_updates = 30000

num_hiddens = 32  # Original: 128, 32 used for masks
num_residual_hiddens = 32
num_residual_layers = 2  # Original was 2

embedding_dim = 3
num_embeddings = 2  # number of codebook vectors

commitment_cost = 0.25

decay = 0.99
model_name = "dp_bimask_2dim_1024size_tanhindecoder.pt"


def create_mask(model_dir, latents_path, final_output_dir):
    """Load the VQ model checkpoint ``model_dir/model_name`` and decode every
    latent in ``latents_path`` into PNGs under ``final_output_dir``."""
    model = VQModel(num_hiddens, num_residual_layers, num_residual_hiddens,
                    num_embeddings, embedding_dim,
                    commitment_cost, decay).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state_dict = torch.load(os.path.join(model_dir, model_name), map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    mkdir(final_output_dir)
    generate_images_from_diffusion_latents(model=model, latents_path=latents_path,
                                           output_dir=final_output_dir)