neverix commited on
Commit
dee645c
·
0 Parent(s):

Initial commit

Browse files
Files changed (5) hide show
  1. .gitignore +3 -0
  2. app.py +35 -0
  3. pulsar_clip.py +222 -0
  4. requirements.txt +6 -0
  5. utils.py +71 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .idea/
2
+ **/__pycache__/
3
+ flagged/
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pulsar_clip import PulsarCLIP, CONFIG_SPEC
2
+ from datetime import datetime
3
+ import gradio as gr
4
+
5
+
6
+ def generate(*args):
7
+ pc = PulsarCLIP(dict([(k, t(v) if not isinstance(t, (tuple, list)) else v)
8
+ for v, (k, v0, t) in zip(args, CONFIG_SPEC)]))
9
+ frames = []
10
+ for image in pc.generate():
11
+ frames.append(image)
12
+ from tqdm.auto import tqdm
13
+ from subprocess import Popen, PIPE
14
+ fps = 30
15
+ video_path = f"{datetime.strftime(datetime.now())}.mp4"
16
+ if frames:
17
+ p = Popen((f"ffmpeg -y -f image2pipe -vcodec png -r {fps} -i - -vcodec libx264 -r {fps} "
18
+ f"-pix_fmt yuv420p -crf 17 -preset fast ").split() + [str(video_path)], stdin=PIPE)
19
+ for im in tqdm(frames):
20
+ im.save(p.stdin, "PNG")
21
+ p.stdin.close()
22
+ p.wait()
23
+ return video_path
24
+
25
+
26
+ def main():
27
+ gr.Interface(inputs=[
28
+ (gr.inputs.Number(label=k, default=v0) if t in (float, int) else
29
+ gr.inputs.Checkbox(label=k, default=v0) if t == bool else gr.inputs.Textbox(label=k, default=v0) if t == str
30
+ else gr.inputs.Dropdown(label=k, default=v0, choices=t) if isinstance(t, (tuple, list)) else 1/0)
31
+ for k, v0, t in CONFIG_SPEC], outputs=gr.outputs.Video(), fn=generate).launch()
32
+
33
+
34
+ if __name__ == '__main__':
35
+ main()
pulsar_clip.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import set_seed
2
+ from tqdm.auto import trange
3
+ from PIL import Image
4
+ import numpy as np
5
+ import random
6
+ import utils
7
+ import torch
8
+
9
+
10
+ CONFIG_SPEC = [
11
+ ("text", "A cloud at dawn", str),
12
+ ("iterations", 5000, int),
13
+ ("turns", 4, int),
14
+ ("showoff", 5000, int),
15
+ ("seed", 12, int),
16
+ ("focal_length", 0.1, float),
17
+ ("plane_width", 0.1, float),
18
+ ("shade_strength", 0.25, float),
19
+ ("gamma", 0.5, float),
20
+ ("max_depth", 7, float),
21
+ ("lr", 0.5, float),
22
+ ("offset", 5, float),
23
+ ("offset_random", 0.75, float),
24
+ ("xyz_random", 0.25, float),
25
+ ("altitude_range", 0.3, float),
26
+ ("augments", 4, int),
27
+ ("show_every", 50, int),
28
+ ("epochs", 1, int),
29
+ ("w", 224, int),
30
+ ("h", 224, int),
31
+ ("num_objects", 256, int),
32
+ #@markdown CLIP loss type, might improve the results
33
+ ("loss_type", "spherical", ("spherical", "cosine")),
34
+ #@markdown CLIP loss weight
35
+ ("clip_weight", 1.0, float), #@param {type: "number"}
36
+ #@markdown Number of dimensions. 0 is for point clouds (default), 1 will make
37
+ #@markdown strokes, 2 will make planes, 3 produces little cubes
38
+ ("ndim", 0, (0, 1, 2, 3)), #@param {type: "integer"}
39
+
40
+ #@markdown Opacity scale:
41
+ ("min_opacity", 1e-4, float), #@param {type: "number"}
42
+ ("max_opacity", 1.0, float), #@param {type: "number"}
43
+ ("log_opacity", False, bool), #@param {type: "boolean"}
44
+
45
+ ("min_radius", 0.030, float),
46
+ ("max_radius", 0.070, float),
47
+ ("log_radius", False, bool),
48
+
49
+ # TODO dynamically decide bezier_res
50
+ #@markdown Bezier resolution: how many points a line/plane/cube will have. Not applicable to points
51
+ ("bezier_res", 8, int), #@param {type: "integer"}
52
+ #@markdown Maximum scale of parameters: position, velocity, acceleration
53
+ ("pos_scale", 0.4, float), #@param {type: "number"}
54
+ ("vel_scale", 0.15, float), #@param {type: "number"}
55
+ ("acc_scale", 0.15, float), #@param {type: "number"}
56
+
57
+ #@markdown Scale of each individual 3D object. Master control for velocity and acceleration scale.
58
+ ("scale", 1, float), #@param {type: "number"}
59
+ ]
60
+
61
+
62
+ # TODO: one day separate the config into multiple parts and split this megaobject into multiple objects
63
+ class PulsarCLIP(object):
64
+ def __init__(self, args):
65
+ args = DotDict(**args)
66
+ set_seed(args.seed)
67
+ self.args = args
68
+ self.device = args.get("device", "cuda" if torch.cuda.is_available() else "cpu")
69
+ # Defer the import so that we can import `pulsar_clip` and then install `pytorch3d`
70
+ import pytorch3d.renderer.points.pulsar as ps
71
+ self.ndim = int(self.args.ndim)
72
+ self.renderer = ps.Renderer(self.args.w, self.args.h,
73
+ self.args.num_objects * (self.args.bezier_res ** self.ndim)).to(self.device)
74
+ self.bezier_pos = torch.nn.Parameter(torch.randn((args.num_objects, 4)).to(self.device))
75
+ self.bezier_vel = torch.nn.Parameter(torch.randn((args.num_objects, 3 * self.ndim)).to(self.device))
76
+ self.bezier_acc = torch.nn.Parameter(torch.randn((args.num_objects, 3 * self.ndim)).to(self.device))
77
+ self.bezier_col = torch.nn.Parameter(torch.randn((args.num_objects, 4 * (1 + self.ndim))).to(self.device))
78
+ self.optimizer = torch.optim.Adam([dict(params=[self.bezier_col], lr=5e-1 * args.lr),
79
+ dict(params=[self.bezier_pos], lr=1e-1 * args.lr),
80
+ dict(params=[self.bezier_vel, self.bezier_acc], lr=5e-2 * args.lr),
81
+ ])
82
+ self.model_clip, self.preprocess_clip = utils.load_clip()
83
+ self.model_clip.visual.requires_grad_(False)
84
+ self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer,
85
+ int(self.args.iterations
86
+ / self.args.augments
87
+ / self.args.epochs))
88
+ import clip
89
+ self.txt_emb = self.model_clip.encode_text(clip.tokenize([self.args.text]).to(self.device))[0].detach()
90
+ self.txt_emb = torch.nn.functional.normalize(self.txt_emb, dim=-1)
91
+
92
+ def get_points(self):
93
+ if self.ndim > 0:
94
+ bezier_ts = torch.stack(torch.meshgrid(
95
+ (torch.linspace(0, 1, self.args.bezier_res, device=self.device),) * self.ndim), dim=0
96
+ ).unsqueeze(1).repeat((1, self.args.num_objects) + (1,) * self.ndim).unsqueeze(-1)
97
+
98
+ def interpolate_3D(pos, vel=0.0, acc=0.0, pos_scale=None, vel_scale=None, acc_scale=None, scale=None):
99
+ pos_scale = self.args.pos_scale if pos_scale is None else pos_scale
100
+ vel_scale = self.args.vel_scale if vel_scale is None else vel_scale
101
+ acc_scale = self.args.acc_scale if acc_scale is None else acc_scale
102
+ scale = self.args.scale if scale is None else scale
103
+ if self.ndim == 0:
104
+ return pos * pos_scale
105
+ result = 0.0
106
+ s = pos.shape[-1]
107
+ assert s * self.ndim == vel.shape[-1] == acc.shape[-1]
108
+ # O(dim) sequential lol
109
+ for d, bezier_t in zip(range(self.ndim), bezier_ts): # TODO replace with fused dimension operation
110
+ result = (result
111
+ + torch.tanh(vel[..., d * s:(d + 1) * s]).view(
112
+ (-1,) + (1,) * self.ndim + (s,)) * vel_scale * bezier_t
113
+ + torch.tanh(acc[..., d * s:(d + 1) * s]).view(
114
+ (-1,) + (1,) * self.ndim + (s,)) * acc_scale * bezier_t.pow(2))
115
+ result = (result * scale
116
+ + torch.tanh(pos[..., :s]).view((-1,) + (1,) * self.ndim + (s,)) * pos_scale).view(-1, s)
117
+ return result
118
+
119
+ vert_pos = interpolate_3D(self.bezier_pos[..., :3], self.bezier_vel, self.bezier_acc)
120
+ vert_col = interpolate_3D(self.bezier_col[..., :4],
121
+ self.bezier_col[..., 4:4 + 4 * self.ndim],
122
+ self.bezier_col[..., -4 * self.ndim:])
123
+
124
+ to_bezier = lambda x: x.view((-1,) + (1,) * self.ndim + (x.shape[-1],)).repeat(
125
+ (1,) + (self.args.bezier_res,) * self.ndim + (1,)).reshape(-1, x.shape[-1])
126
+ rescale = lambda x, a, b, is_log=False: (torch.exp(x
127
+ * np.log(b / a)
128
+ + np.log(a))) if is_log else x * (b - a) + a
129
+ return (
130
+ vert_pos,
131
+ torch.sigmoid(vert_col[..., :3]),
132
+ rescale(
133
+ torch.sigmoid(to_bezier(self.bezier_pos[..., -1:])[..., 0]),
134
+ self.args.min_radius, self.args.max_radius, is_log=self.args.log_radius
135
+ ),
136
+ rescale(torch.sigmoid(vert_col[..., -1]),
137
+ self.args.min_opacity, self.args.max_opacity, is_log=self.args.log_opacity))
138
+
139
+ def camera(self, angle, altitude=0.0, offset=None, use_random=True, offset_random=None,
140
+ xyz_random=None, focal_length=None, plane_width=None):
141
+ if offset is None:
142
+ offset = self.args.offset
143
+ if xyz_random is None:
144
+ xyz_random = self.args.xyz_random
145
+ if focal_length is None:
146
+ focal_length = self.args.focal_length
147
+ if plane_width is None:
148
+ plane_width = self.args.plane_width
149
+ if offset_random is None:
150
+ offset_random = self.args.offset_random
151
+ device = self.device
152
+ offset = offset + np.random.normal() * offset_random * int(use_random)
153
+ position = torch.tensor([0, 0, -offset], dtype=torch.float)
154
+ position = utils.rotate_axis(position, altitude, 0)
155
+ position = utils.rotate_axis(position, angle, 1)
156
+ position = position + torch.randn(3) * xyz_random * int(use_random)
157
+ return torch.tensor([position[0], position[1], position[2],
158
+ altitude, angle, 0,
159
+ focal_length, plane_width], dtype=torch.float, device=device)
160
+
161
+
162
+ def render(self, cam_params=None):
163
+ if cam_params is None:
164
+ cam_params = self.camera(0, 0)
165
+ vert_pos, vert_col, radius, opacity = self.get_points()
166
+
167
+ rgb = self.renderer(vert_pos, vert_col, radius, cam_params,
168
+ self.args.gamma, self.args.max_depth, opacity=opacity)
169
+ opacity = self.renderer(vert_pos, vert_col * 0, radius, cam_params,
170
+ self.args.gamma, self.args.max_depth, opacity=opacity)
171
+ return rgb, opacity
172
+
173
+ def random_view_render(self):
174
+ angle = random.uniform(0, np.pi * 2)
175
+ altitude = random.uniform(-self.args.altitude_range / 2, self.args.altitude_range / 2)
176
+ cam_params = self.camera(angle, altitude)
177
+ result, alpha = self.render(cam_params)
178
+ back = torch.zeros_like(result)
179
+ s = back.shape
180
+ for j in range(s[-1]):
181
+ n = random.choice([7, 14, 28])
182
+ back[..., j] = utils.rand_perlin_2d_octaves(s[:-1], (n, n)).clip(-0.5, 0.5) + 0.5
183
+ result = result * (1 - alpha) + back * alpha
184
+ return result
185
+
186
+
187
+ def generate(self):
188
+ self.optimizer.zero_grad()
189
+ try:
190
+ for i in trange(self.args.iterations + self.args.showoff):
191
+ if i < self.args.iterations:
192
+ result = self.random_view_render()
193
+ img_emb = self.model_clip.encode_image(
194
+ self.preprocess_clip(result.permute(2, 0, 1)).unsqueeze(0).clamp(0., 1.))
195
+ img_emb = torch.nn.functional.normalize(img_emb, dim=-1)
196
+ if self.args.loss_type == "spherical":
197
+ clip_loss = (img_emb - self.txt_emb).norm(dim=-1).div(2).arcsin().pow(2).mul(2).mean()
198
+ elif self.args.loss_type == "cosine":
199
+ clip_loss = (1 - img_emb @ self.txt_emb.T).mean()
200
+ else:
201
+ raise NotImplementedError(f"CLIP loss type not supported: {self.args.loss_type}")
202
+ loss = clip_loss * self.args.clip_weight + (0 and ...) # TODO add more loss types
203
+ loss.backward()
204
+ if i % self.args.augments == self.args.augments - 1:
205
+ self.optimizer.step()
206
+ self.optimizer.zero_grad()
207
+ try:
208
+ self.scheduler.step()
209
+ except AttributeError:
210
+ pass
211
+ if i % self.args.show_every == 0:
212
+ cam_params = self.camera(i / self.args.iterations * np.pi * 2 * self.args.turns, use_random=False)
213
+ img_show, _ = self.render(cam_params)
214
+ img = Image.fromarray((img_show.cpu().detach().numpy() * 255).astype(np.uint8))
215
+ yield img
216
+ except KeyboardInterrupt:
217
+ pass
218
+
219
+
220
+ class DotDict(dict):
221
+ def __getattr__(self, item):
222
+ return self.__getitem__(item)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ pytorch3d==0.6.2
2
+ transformers==4.10.3
3
+ torch==1.11.0+cu113
4
+ torchvision==0.12.0+cu113
5
+ clip
6
+ gradio
utils.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import torch
3
+ import math
4
+
5
+
6
+ def rotate_axis(x, add_angle=0, axis=1): # TODO Replace with a rotation matrix # But this is more fun
7
+ axes = list(range(3))
8
+ axes.remove(axis)
9
+ ax1, ax2 = axes
10
+ angle = torch.atan2(x[..., ax1], x[..., ax2])
11
+ if isinstance(add_angle, torch.Tensor):
12
+ while add_angle.ndim < angle.ndim:
13
+ add_angle = add_angle.unsqueeze(-1)
14
+ angle = angle + add_angle
15
+ dist = x.norm(dim=-1)
16
+ t = []
17
+ _, t = zip(*sorted([
18
+ (axis, x[..., axis]),
19
+ (ax1, torch.sin(angle) * dist),
20
+ (ax2, torch.cos(angle) * dist),
21
+ ]))
22
+ return torch.stack(t, dim=-1)
23
+
24
+
25
+ noise_level = 0.5
26
+
27
+
28
+ # stolen from https://gist.github.com/ac1b097753f217c5c11bc2ff396e0a57
29
+ # ported from https://github.com/pvigier/perlin-numpy/blob/master/perlin2d.py
30
+ def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
31
+ delta = (res[0] / shape[0], res[1] / shape[1])
32
+ d = (shape[0] // res[0], shape[1] // res[1])
33
+
34
+ grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1
35
+ angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
36
+ gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
37
+
38
+ tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0],
39
+ 0).repeat_interleave(
40
+ d[1], 1)
41
+ dot = lambda grad, shift: (
42
+ torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
43
+ dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)
44
+
45
+ n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
46
+ n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
47
+ n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
48
+ n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
49
+ t = fade(grid[:shape[0], :shape[1]])
50
+ return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
51
+
52
+
53
+ def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5):
54
+ noise = torch.zeros(shape)
55
+ frequency = 1
56
+ amplitude = 1
57
+ for _ in range(octaves):
58
+ noise += amplitude * rand_perlin_2d(shape, (frequency * res[0], frequency * res[1]))
59
+ frequency *= 2
60
+ amplitude *= persistence
61
+ noise *= random.random() - noise_level # haha
62
+ noise += random.random() - noise_level # haha x2
63
+ return noise
64
+
65
+
66
+ def load_clip(model_name="ViT-B/16", device="cuda:0" if torch.cuda.is_available() else "cpu"):
67
+ import clip
68
+ model, preprocess = clip.load(model_name, device=device, jit=False)
69
+ if len(preprocess.transforms) > 4:
70
+ preprocess.transforms = preprocess.transforms[-1:]
71
+ return model, preprocess