srijaydeshpande commited on
Commit
c6c8618
·
verified ·
1 Parent(s): e87308d

Upload gan_losses.py

Browse files
Files changed (1) hide show
  1. gan_losses.py +213 -0
gan_losses.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ #
3
+ # Copyright 2018 Google LLC
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torchvision import models
20
+ from torch import nn
21
+
22
+
23
+ def get_gan_losses(gan_type):
24
+ """
25
+ Returns the generator and discriminator loss for a particular GAN type.
26
+
27
+ The returned functions have the following API:
28
+ loss_g = g_loss(scores_fake)
29
+ loss_d = d_loss(scores_real, scores_fake)
30
+ """
31
+ if gan_type == 'gan':
32
+ return gan_g_loss, gan_d_loss
33
+ elif gan_type == 'wgan':
34
+ return wgan_g_loss, wgan_d_loss
35
+ elif gan_type == 'lsgan':
36
+ return lsgan_g_loss, lsgan_d_loss
37
+ else:
38
+ raise ValueError('Unrecognized GAN type "%s"' % gan_type)
39
+
40
+
41
+ def bce_loss(input, target):
42
+ """
43
+ Numerically stable version of the binary cross-entropy loss function.
44
+
45
+ As per https://github.com/pytorch/pytorch/issues/751
46
+ See the TensorFlow docs for a derivation of this formula:
47
+ https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
48
+
49
+ Inputs:
50
+ - input: PyTorch Tensor of shape (N, ) giving scores.
51
+ - target: PyTorch Tensor of shape (N,) containing 0 and 1 giving targets.
52
+
53
+ Returns:
54
+ - A PyTorch Tensor containing the mean BCE loss over the minibatch of
55
+ input data.
56
+ """
57
+ neg_abs = -input.abs()
58
+ loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
59
+ return loss.mean()
60
+
61
+
62
+ def _make_targets(x, y):
63
+ """
64
+ Inputs:
65
+ - x: PyTorch Tensor
66
+ - y: Python scalar
67
+
68
+ Outputs:
69
+ - out: PyTorch Variable with same shape and dtype as x, but filled with y
70
+ """
71
+ return torch.full_like(x, y)
72
+
73
+
74
+ def gan_g_loss(scores_fake):
75
+ """
76
+ Input:
77
+ - scores_fake: Tensor of shape (N,) containing scores for fake samples
78
+
79
+ Output:
80
+ - loss: Variable of shape (,) giving GAN generator loss
81
+ """
82
+ if scores_fake.dim() > 1:
83
+ scores_fake = scores_fake.view(-1)
84
+ y_fake = _make_targets(scores_fake, 1)
85
+ return bce_loss(scores_fake, y_fake)
86
+
87
+
88
+ def gan_d_loss(scores_real, scores_fake):
89
+ """
90
+ Input:
91
+ - scores_real: Tensor of shape (N,) giving scores for real samples
92
+ - scores_fake: Tensor of shape (N,) giving scores for fake samples
93
+
94
+ Output:
95
+ - loss: Tensor of shape (,) giving GAN discriminator loss
96
+ """
97
+ assert scores_real.size() == scores_fake.size()
98
+ if scores_real.dim() > 1:
99
+ scores_real = scores_real.view(-1)
100
+ scores_fake = scores_fake.view(-1)
101
+ y_real = _make_targets(scores_real, 1)
102
+ y_fake = _make_targets(scores_fake, 0)
103
+ loss_real = bce_loss(scores_real, y_real)
104
+ loss_fake = bce_loss(scores_fake, y_fake)
105
+ return loss_real + loss_fake
106
+
107
+
108
+ def wgan_g_loss(scores_fake):
109
+ """
110
+ Input:
111
+ - scores_fake: Tensor of shape (N,) containing scores for fake samples
112
+
113
+ Output:
114
+ - loss: Tensor of shape (,) giving WGAN generator loss
115
+ """
116
+ return -scores_fake.mean()
117
+
118
+
119
+ def wgan_d_loss(scores_real, scores_fake):
120
+ """
121
+ Input:
122
+ - scores_real: Tensor of shape (N,) giving scores for real samples
123
+ - scores_fake: Tensor of shape (N,) giving scores for fake samples
124
+
125
+ Output:
126
+ - loss: Tensor of shape (,) giving WGAN discriminator loss
127
+ """
128
+ return scores_fake.mean() - scores_real.mean()
129
+
130
+
131
+ def lsgan_g_loss(scores_fake):
132
+ if scores_fake.dim() > 1:
133
+ scores_fake = scores_fake.view(-1)
134
+ y_fake = _make_targets(scores_fake, 1)
135
+ return F.mse_loss(scores_fake.sigmoid(), y_fake)
136
+
137
+
138
+ def lsgan_d_loss(scores_real, scores_fake):
139
+ assert scores_real.size() == scores_fake.size()
140
+ if scores_real.dim() > 1:
141
+ scores_real = scores_real.view(-1)
142
+ scores_fake = scores_fake.view(-1)
143
+ y_real = _make_targets(scores_real, 1)
144
+ y_fake = _make_targets(scores_fake, 0)
145
+ loss_real = F.mse_loss(scores_real.sigmoid(), y_real)
146
+ loss_fake = F.mse_loss(scores_fake.sigmoid(), y_fake)
147
+ return loss_real + loss_fake
148
+
149
+
150
+ def gradient_penalty(x_real, x_fake, f, gamma=1.0):
151
+ N = x_real.size(0)
152
+ device, dtype = x_real.device, x_real.dtype
153
+ eps = torch.randn(N, 1, 1, 1, device=device, dtype=dtype)
154
+ x_hat = eps * x_real + (1 - eps) * x_fake
155
+ x_hat_score = f(x_hat)
156
+ if x_hat_score.dim() > 1:
157
+ x_hat_score = x_hat_score.view(x_hat_score.size(0), -1).mean(dim=1)
158
+ x_hat_score = x_hat_score.sum()
159
+ grad_x_hat, = torch.autograd.grad(x_hat_score, x_hat, create_graph=True)
160
+ grad_x_hat_norm = grad_x_hat.contiguous().view(N, -1).norm(p=2, dim=1)
161
+ gp_loss = (grad_x_hat_norm - gamma).pow(2).div(gamma * gamma).mean()
162
+ return gp_loss
163
+
164
+ # VGG Features matching
165
+ class Vgg19(torch.nn.Module):
166
+ def __init__(self, requires_grad=False):
167
+ super(Vgg19, self).__init__()
168
+ vgg_pretrained_features = models.vgg19(pretrained=True).features
169
+ self.slice1 = torch.nn.Sequential()
170
+ self.slice2 = torch.nn.Sequential()
171
+ self.slice3 = torch.nn.Sequential()
172
+ self.slice4 = torch.nn.Sequential()
173
+ self.slice5 = torch.nn.Sequential()
174
+ for x in range(2):
175
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
176
+ for x in range(2, 7):
177
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
178
+ for x in range(7, 12):
179
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
180
+ for x in range(12, 21):
181
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
182
+ for x in range(21, 30):
183
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
184
+ if not requires_grad:
185
+ for param in self.parameters():
186
+ param.requires_grad = False
187
+
188
+ def forward(self, X):
189
+ h_relu1 = self.slice1(X)
190
+ h_relu2 = self.slice2(h_relu1)
191
+ h_relu3 = self.slice3(h_relu2)
192
+ h_relu4 = self.slice4(h_relu3)
193
+ h_relu5 = self.slice5(h_relu4)
194
+ out = [h_relu5, h_relu2, h_relu3, h_relu4, h_relu5]
195
+ return out
196
+
197
+
198
+ class VGGLoss(nn.Module):
199
+ def __init__(self):
200
+ super(VGGLoss, self).__init__()
201
+ if torch.cuda.is_available():
202
+ self.vgg = Vgg19().cuda()
203
+ else:
204
+ self.vgg = Vgg19()
205
+ self.criterion = nn.L1Loss()
206
+ self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
207
+
208
+ def forward(self, x, y):
209
+ x_vgg, y_vgg = self.vgg(x), self.vgg(y)
210
+ loss = 0
211
+ for i in range(len(x_vgg)):
212
+ loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
213
+ return loss