liuganghuggingface commited on
Commit
959d452
·
verified ·
1 Parent(s): 61e07a1

Upload graph_decoder/conditions.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. graph_decoder/conditions.py +84 -0
graph_decoder/conditions.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ import torch.nn.functional as F
5
+
6
+ class TimestepEmbedder(nn.Module):
7
+ """
8
+ Embeds scalar timesteps into vector representations.
9
+ """
10
+ def __init__(self, hidden_size, frequency_embedding_size=256):
11
+ super().__init__()
12
+ self.mlp = nn.Sequential(
13
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
14
+ nn.SiLU(),
15
+ nn.Linear(hidden_size, hidden_size, bias=True),
16
+ )
17
+ self.frequency_embedding_size = frequency_embedding_size
18
+
19
+ @staticmethod
20
+ def timestep_embedding(t, dim, max_period=10000):
21
+ """
22
+ Create sinusoidal timestep embeddings.
23
+ :param t: a 1-D Tensor of N indices, one per batch element.
24
+ These may be fractional.
25
+ :param dim: the dimension of the output.
26
+ :param max_period: controls the minimum frequency of the embeddings.
27
+ :return: an (N, D) Tensor of positional embeddings.
28
+ """
29
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
30
+ half = dim // 2
31
+ freqs = torch.exp(
32
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
33
+ ).to(device=t.device)
34
+ args = t[:, None].float() * freqs[None]
35
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
36
+ if dim % 2:
37
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
38
+ return embedding
39
+
40
+ def forward(self, t):
41
+ t = t.view(-1)
42
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
43
+ t_emb = self.mlp(t_freq)
44
+ return t_emb
45
+
46
+ class ConditionEmbedder(nn.Module):
47
+ def __init__(self, input_size, hidden_size, dropout_prob, max_weight=1.0, sigma_factor=0.25):
48
+ super().__init__()
49
+ self.embedding_drop = nn.Embedding(input_size, hidden_size)
50
+
51
+ self.mlps = nn.ModuleList([
52
+ nn.Sequential(
53
+ nn.Linear(1, hidden_size, bias=True),
54
+ nn.Softmax(dim=1),
55
+ nn.Linear(hidden_size, hidden_size, bias=False)
56
+ ) for _ in range(input_size)
57
+ ])
58
+
59
+ self.hidden_size = hidden_size
60
+ self.dropout_prob = dropout_prob
61
+
62
+ def forward(self, labels, train, unconditioned):
63
+ embeddings = 0
64
+
65
+ for dim in range(labels.shape[1]):
66
+ label = labels[:, dim]
67
+ if unconditioned:
68
+ drop_ids = torch.ones_like(label).bool()
69
+ else:
70
+ drop_ids = torch.isnan(label)
71
+ if train:
72
+ random_tensor = torch.rand(label.shape).type_as(labels)
73
+ probability_mask = random_tensor < self.dropout_prob
74
+ drop_ids = drop_ids | probability_mask
75
+
76
+ label = label.unsqueeze(1)
77
+ embedding = torch.zeros((label.shape[0], self.hidden_size)).type_as(labels)
78
+ mlp_out = self.mlps[dim](label[~drop_ids])
79
+ embedding[~drop_ids] = mlp_out.type_as(embedding)
80
+ embedding[drop_ids] += self.embedding_drop.weight[dim]
81
+
82
+ embeddings += embedding
83
+
84
+ return embeddings