# DiffusionGenerator / gan_losses.py
# srijaydeshpande's picture
# Upload gan_losses.py
# c6c8618 verified
#!/usr/bin/python
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torchvision import models
from torch import nn
def get_gan_losses(gan_type):
    """
    Look up the generator/discriminator loss pair for a GAN variant.

    Inputs:
    - gan_type: one of 'gan', 'wgan' or 'lsgan'

    Returns:
    - A tuple (g_loss, d_loss) of functions with the following API:
        loss_g = g_loss(scores_fake)
        loss_d = d_loss(scores_real, scores_fake)

    Raises:
    - ValueError if gan_type is not one of the names above.
    """
    # Reject unknown names up front so the lookup below can never miss.
    if gan_type not in ('gan', 'wgan', 'lsgan'):
        raise ValueError('Unrecognized GAN type "%s"' % gan_type)

    loss_table = {
        'gan': (gan_g_loss, gan_d_loss),
        'wgan': (wgan_g_loss, wgan_d_loss),
        'lsgan': (lsgan_g_loss, lsgan_d_loss),
    }
    return loss_table[gan_type]
def bce_loss(input, target):
    """
    Numerically stable version of the binary cross-entropy loss function,
    operating on raw scores (logits) rather than probabilities.

    As per https://github.com/pytorch/pytorch/issues/751
    See the TensorFlow docs for a derivation of this formula:
    https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits

    The per-element loss is
        max(x, 0) - x * z + log(1 + exp(-|x|))
    for score x and target z.

    Inputs:
    - input: PyTorch Tensor of shape (N,) giving scores.
    - target: PyTorch Tensor of shape (N,) containing 0 and 1 giving targets.

    Returns:
    - A PyTorch Tensor containing the mean BCE loss over the minibatch of
      input data.
    """
    # log1p keeps the tiny exp(-|x|) term for large |x|; forming 1 + exp(-|x|)
    # first rounds to exactly 1 in float32 and silently drops that term.
    log_term = torch.log1p(torch.exp(-input.abs()))
    loss = input.clamp(min=0) - input * target + log_term
    return loss.mean()
def _make_targets(x, y):
"""
Inputs:
- x: PyTorch Tensor
- y: Python scalar
Outputs:
- out: PyTorch Variable with same shape and dtype as x, but filled with y
"""
return torch.full_like(x, y)
def gan_g_loss(scores_fake):
    """
    Standard GAN generator loss: BCE of the fake scores against all-ones
    targets.

    Input:
    - scores_fake: Tensor containing scores for fake samples; tensors with
      more than one dimension are flattened to shape (N,) first

    Output:
    - loss: Tensor of shape (,) giving GAN generator loss
    """
    flat_scores = scores_fake.view(-1) if scores_fake.dim() > 1 else scores_fake
    return bce_loss(flat_scores, _make_targets(flat_scores, 1))
def gan_d_loss(scores_real, scores_fake):
    """
    Standard GAN discriminator loss: BCE of real scores against ones plus BCE
    of fake scores against zeros.

    Input:
    - scores_real: Tensor giving scores for real samples
    - scores_fake: Tensor of the same size giving scores for fake samples

    Both tensors are flattened to shape (N,) when scores_real has more than
    one dimension.

    Output:
    - loss: Tensor of shape (,) giving GAN discriminator loss
    """
    assert scores_real.size() == scores_fake.size()
    if scores_real.dim() > 1:
        scores_real, scores_fake = scores_real.view(-1), scores_fake.view(-1)

    real_term = bce_loss(scores_real, _make_targets(scores_real, 1))
    fake_term = bce_loss(scores_fake, _make_targets(scores_fake, 0))
    return real_term + fake_term
def wgan_g_loss(scores_fake):
    """
    WGAN generator loss: the negated mean score of the fake samples.

    Input:
    - scores_fake: Tensor containing scores for fake samples

    Output:
    - loss: Tensor of shape (,) giving WGAN generator loss
    """
    mean_fake = scores_fake.mean()
    return mean_fake.neg()
def wgan_d_loss(scores_real, scores_fake):
    """
    WGAN discriminator (critic) loss: mean fake score minus mean real score.

    Input:
    - scores_real: Tensor giving scores for real samples
    - scores_fake: Tensor giving scores for fake samples

    Output:
    - loss: Tensor of shape (,) giving WGAN discriminator loss
    """
    mean_fake = scores_fake.mean()
    mean_real = scores_real.mean()
    return mean_fake - mean_real
def lsgan_g_loss(scores_fake):
    """
    Least-squares GAN generator loss: MSE between the sigmoid of the fake
    scores and all-ones targets.

    Input:
    - scores_fake: Tensor containing scores for fake samples; tensors with
      more than one dimension are flattened first

    Output:
    - loss: Tensor of shape (,) giving LSGAN generator loss
    """
    flat_scores = scores_fake.view(-1) if scores_fake.dim() > 1 else scores_fake
    squashed = torch.sigmoid(flat_scores)
    return F.mse_loss(squashed, _make_targets(flat_scores, 1))
def lsgan_d_loss(scores_real, scores_fake):
    """
    Least-squares GAN discriminator loss: MSE of sigmoid(real scores) against
    ones plus MSE of sigmoid(fake scores) against zeros.

    Input:
    - scores_real: Tensor giving scores for real samples
    - scores_fake: Tensor of the same size giving scores for fake samples

    Both tensors are flattened when scores_real has more than one dimension.

    Output:
    - loss: Tensor of shape (,) giving LSGAN discriminator loss
    """
    assert scores_real.size() == scores_fake.size()
    if scores_real.dim() > 1:
        scores_real, scores_fake = scores_real.view(-1), scores_fake.view(-1)

    real_term = F.mse_loss(torch.sigmoid(scores_real),
                           _make_targets(scores_real, 1))
    fake_term = F.mse_loss(torch.sigmoid(scores_fake),
                           _make_targets(scores_fake, 0))
    return real_term + fake_term
def gradient_penalty(x_real, x_fake, f, gamma=1.0):
    """
    Gradient penalty evaluated at random interpolates between real and fake
    samples (as in WGAN-GP).

    Inputs:
    - x_real: Tensor of shape (N, ...) giving real samples
    - x_fake: Tensor broadcastable with x_real giving fake samples
    - f: callable mapping a batch of samples to scores; outputs with more
      than one dimension are averaged per sample
    - gamma: Python scalar giving the target gradient norm

    Output:
    - gp_loss: Tensor of shape (,) giving the batch mean of
      (||grad_x_hat f(x_hat)||_2 - gamma)^2 / gamma^2
    """
    N = x_real.size(0)
    device, dtype = x_real.device, x_real.dtype

    # One mixing coefficient per sample, drawn uniformly from [0, 1) so that
    # x_hat lies on the segment between the real and the fake sample. The
    # trailing singleton dims follow the input rank so it broadcasts per sample.
    eps_shape = [N] + [1] * (x_real.dim() - 1)
    eps = torch.rand(eps_shape, device=device, dtype=dtype)
    x_hat = eps * x_real + (1 - eps) * x_fake

    # autograd.grad needs x_hat to be part of the graph; when neither input
    # requires grad (e.g. a detached fake batch) x_hat is a plain leaf tensor.
    if not x_hat.requires_grad:
        x_hat.requires_grad_(True)

    x_hat_score = f(x_hat)
    if x_hat_score.dim() > 1:
        # Reduce e.g. patch-wise outputs to one score per sample.
        x_hat_score = x_hat_score.view(x_hat_score.size(0), -1).mean(dim=1)
    x_hat_score = x_hat_score.sum()

    # create_graph=True keeps the penalty itself differentiable.
    grad_x_hat, = torch.autograd.grad(x_hat_score, x_hat, create_graph=True)
    grad_x_hat_norm = grad_x_hat.contiguous().view(N, -1).norm(p=2, dim=1)
    gp_loss = (grad_x_hat_norm - gamma).pow(2).div(gamma * gamma).mean()
    return gp_loss
# VGG feature matching (perceptual loss)
class Vgg19(torch.nn.Module):
    """
    Pretrained VGG-19 feature extractor split into five consecutive slices.

    The slices cover indices [0, 2), [2, 7), [7, 12), [12, 21) and [21, 30)
    of torchvision's ``vgg19().features``; each layer keeps its original
    index as its module name.
    NOTE(review): the h_relu* names suggest each slice ends at relu{k}_1 --
    confirm against the torchvision layer layout.

    Inputs:
    - requires_grad: if False (default), all parameters are frozen
    """

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        # Prefer the `weights=` API where torchvision provides it; the
        # `pretrained=True` flag is deprecated there but is the only option
        # on older releases.
        if hasattr(models, 'VGG19_Weights'):
            vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
        else:
            vgg = models.vgg19(pretrained=True)
        vgg_pretrained_features = vgg.features

        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()

        layer_ranges = (
            (self.slice1, range(0, 2)),
            (self.slice2, range(2, 7)),
            (self.slice3, range(7, 12)),
            (self.slice4, range(12, 21)),
            (self.slice5, range(21, 30)),
        )
        for block, indices in layer_ranges:
            for idx in indices:
                block.add_module(str(idx), vgg_pretrained_features[idx])

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """
        Run X through the five slices in sequence.

        Output:
        - out: list of five Tensors, the outputs of slice1 through slice5 in
          that order
        """
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        # Slot 0 must hold the first slice's features, not a second copy of
        # the last slice's.
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out
class VGGLoss(nn.Module):
    """
    Weighted sum of L1 distances between the Vgg19 features of two inputs.

    The feature extractor is moved to CUDA at construction time when a CUDA
    device is available.
    """

    def __init__(self):
        super(VGGLoss, self).__init__()
        extractor = Vgg19()
        self.vgg = extractor.cuda() if torch.cuda.is_available() else extractor
        self.criterion = nn.L1Loss()
        # One weight per feature tensor returned by the extractor.
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        """
        Inputs:
        - x: input passed through the feature extractor with gradients
        - y: input whose features are detached before comparison

        Output:
        - loss: weighted sum of the per-level L1 feature distances
        """
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        total = 0
        for idx, feat_x in enumerate(feats_x):
            level_loss = self.criterion(feat_x, feats_y[idx].detach())
            total += self.weights[idx] * level_loss
        return total