ntt123 commited on
Commit
256e1f6
·
1 Parent(s): 3e75f8e

inference code

Browse files
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
app.py CHANGED
@@ -1,7 +1,38 @@
1
  import gradio as gr
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  demo.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from synthesize import synthesize
5
 
 
 
6
 
7
def text_to_speech(text, speaker_id, cfg_scale, num_sampling_steps):
    """Run the two-stage diffusion TTS pipeline on *text*.

    Returns a ``(sample_rate, audio)`` tuple in the order ``gr.Audio``
    expects for playback.

    NOTE(review): ``speaker_id`` arrives from a ``gr.Dropdown`` of strings;
    this assumes ``synthesize()`` accepts it as-is — confirm whether it must
    be cast to ``int`` first.
    """
    # Config and checkpoint paths are fixed files bundled alongside the app.
    model_files = {
        "duration_model_config": "./train_duration_dit_s.yaml",
        "acoustic_model_config": "./train_acoustic_dit_b.yaml",
        "duration_model_checkpoint": "./duration_model_0120000.pt",
        "acoustic_model_checkpoint": "./acoustic_model_0140000.pt",
    }
    audio, sample_rate = synthesize(
        text=text,
        speaker_id=speaker_id,
        cfg_scale=cfg_scale,
        num_sampling_steps=num_sampling_steps,
        **model_files,
    )
    return (sample_rate, audio)
19
+
20
+
21
# UI wiring: 100 speaker choices (rendered as strings, matching the
# Dropdown's string values) and the preset sampling-step counts.
speaker_ids = [str(i) for i in range(100)]
sampling_steps = [100, 250, 500, 1000]

demo = gr.Interface(
    fn=text_to_speech,
    # Input order must match text_to_speech's positional parameters.
    inputs=[
        gr.Textbox(label="Text"),
        gr.Dropdown(choices=speaker_ids, label="Speaker ID", value="0"),
        gr.Slider(minimum=0, maximum=10, value=4.0, label="CFG Scale"),
        gr.Dropdown(choices=sampling_steps, label="Sampling Steps", value=100),
    ],
    # text_to_speech returns (sample_rate, audio), which gr.Audio accepts.
    outputs=gr.Audio(label="Generated Speech"),
    title="Text to Speech with Diffusion Transformer",
    description="Enter text, select a speaker ID (0-99), and adjust the CFG scale to generate speech.",
    flagging_options=None,
)

demo.launch()
maps.json ADDED
@@ -0,0 +1,989 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "speaker_id_to_idx": {
3
+ "3003": 0,
4
+ "2204": 1,
5
+ "3307": 2,
6
+ "8080": 3,
7
+ "5935": 4,
8
+ "3922": 5,
9
+ "7982": 6,
10
+ "3638": 7,
11
+ "3032": 8,
12
+ "8699": 9,
13
+ "3274": 10,
14
+ "6189": 11,
15
+ "100": 12,
16
+ "7481": 13,
17
+ "5570": 14,
18
+ "176": 15,
19
+ "1638": 16,
20
+ "1776": 17,
21
+ "5655": 18,
22
+ "534": 19,
23
+ "2401": 20,
24
+ "2853": 21,
25
+ "5400": 22,
26
+ "7484": 23,
27
+ "5290": 24,
28
+ "8193": 25,
29
+ "1027": 26,
30
+ "7134": 27,
31
+ "7240": 28,
32
+ "5712": 29,
33
+ "339": 30,
34
+ "5126": 31,
35
+ "3869": 32,
36
+ "7594": 33,
37
+ "4899": 34,
38
+ "5914": 35,
39
+ "7515": 36,
40
+ "922": 37,
41
+ "8142": 38,
42
+ "28": 39,
43
+ "459": 40,
44
+ "114": 41,
45
+ "6927": 42,
46
+ "5012": 43,
47
+ "2481": 44,
48
+ "1923": 45,
49
+ "5618": 46,
50
+ "7995": 47,
51
+ "5054": 48,
52
+ "8474": 49,
53
+ "8527": 50,
54
+ "2299": 51,
55
+ "8498": 52,
56
+ "1271": 53,
57
+ "4356": 54,
58
+ "8066": 55,
59
+ "9022": 56,
60
+ "7139": 57,
61
+ "318": 58,
62
+ "3185": 59,
63
+ "6104": 60,
64
+ "6683": 61,
65
+ "6877": 62,
66
+ "8194": 63,
67
+ "2531": 64,
68
+ "7789": 65,
69
+ "2971": 66,
70
+ "7754": 67,
71
+ "3994": 68,
72
+ "6555": 69,
73
+ "8506": 70,
74
+ "2053": 71,
75
+ "6426": 72,
76
+ "6694": 73,
77
+ "3370": 74,
78
+ "16": 75,
79
+ "6544": 76,
80
+ "5876": 77,
81
+ "5717": 78,
82
+ "93": 79,
83
+ "7145": 80,
84
+ "3615": 81,
85
+ "8848": 82,
86
+ "3361": 83,
87
+ "1463": 84,
88
+ "3923": 85,
89
+ "4957": 86,
90
+ "501": 87,
91
+ "6458": 88,
92
+ "8152": 89,
93
+ "8410": 90,
94
+ "1265": 91,
95
+ "7910": 92,
96
+ "4145": 93,
97
+ "5538": 94,
98
+ "4039": 95,
99
+ "6965": 96,
100
+ "7314": 97,
101
+ "2960": 98,
102
+ "1392": 99,
103
+ "2573": 100,
104
+ "6550": 101,
105
+ "7833": 102,
106
+ "7959": 103,
107
+ "1226": 104,
108
+ "7874": 105,
109
+ "2769": 106,
110
+ "2156": 107,
111
+ "8011": 108,
112
+ "6701": 109,
113
+ "4595": 110,
114
+ "6099": 111,
115
+ "7926": 112,
116
+ "8183": 113,
117
+ "7247": 114,
118
+ "3368": 115,
119
+ "7460": 116,
120
+ "7752": 117,
121
+ "8718": 118,
122
+ "8222": 119,
123
+ "7949": 120,
124
+ "4945": 121,
125
+ "8225": 122,
126
+ "5093": 123,
127
+ "5319": 124,
128
+ "2404": 125,
129
+ "7766": 126,
130
+ "4438": 127,
131
+ "724": 128,
132
+ "6341": 129,
133
+ "7704": 130,
134
+ "1422": 131,
135
+ "3483": 132,
136
+ "3549": 133,
137
+ "8772": 134,
138
+ "4535": 135,
139
+ "1313": 136,
140
+ "7981": 137,
141
+ "7245": 138,
142
+ "6499": 139,
143
+ "3157": 140,
144
+ "3717": 141,
145
+ "8786": 142,
146
+ "7416": 143,
147
+ "209": 144,
148
+ "8075": 145,
149
+ "3851": 146,
150
+ "4110": 147,
151
+ "7540": 148,
152
+ "3482": 149,
153
+ "6828": 150,
154
+ "335": 151,
155
+ "203": 152,
156
+ "7717": 153,
157
+ "7313": 154,
158
+ "3540": 155,
159
+ "3380": 156,
160
+ "4243": 157,
161
+ "688": 158,
162
+ "8725": 159,
163
+ "8820": 160,
164
+ "7994": 161,
165
+ "7128": 162,
166
+ "4098": 163,
167
+ "7933": 164,
168
+ "1446": 165,
169
+ "6567": 166,
170
+ "4586": 167,
171
+ "4973": 168,
172
+ "1195": 169,
173
+ "79": 170,
174
+ "3228": 171,
175
+ "7832": 172,
176
+ "1731": 173,
177
+ "6904": 174,
178
+ "8605": 175,
179
+ "637": 176,
180
+ "5985": 177,
181
+ "7495": 178,
182
+ "7553": 179,
183
+ "6006": 180,
184
+ "7825": 181,
185
+ "7739": 182,
186
+ "242": 183,
187
+ "5909": 184,
188
+ "225": 185,
189
+ "5261": 186,
190
+ "770": 187,
191
+ "4427": 188,
192
+ "731": 189,
193
+ "4116": 190,
194
+ "64": 191,
195
+ "7188": 192,
196
+ "1752": 193,
197
+ "6080": 194,
198
+ "2045": 195,
199
+ "5246": 196,
200
+ "816": 197,
201
+ "7276": 198,
202
+ "207": 199,
203
+ "3584": 200,
204
+ "7720": 201,
205
+ "5731": 202,
206
+ "3703": 203,
207
+ "1806": 204,
208
+ "6038": 205,
209
+ "1913": 206,
210
+ "6139": 207,
211
+ "7229": 208,
212
+ "2004": 209,
213
+ "3119": 210,
214
+ "8113": 211,
215
+ "8329": 212,
216
+ "6497": 213,
217
+ "8684": 214,
218
+ "6120": 215,
219
+ "6233": 216,
220
+ "7286": 217,
221
+ "5489": 218,
222
+ "2056": 219,
223
+ "2562": 220,
224
+ "3046": 221,
225
+ "2256": 222,
226
+ "7120": 223,
227
+ "6286": 224,
228
+ "5039": 225,
229
+ "98": 226,
230
+ "2827": 227,
231
+ "5660": 228,
232
+ "3914": 229,
233
+ "4148": 230,
234
+ "192": 231,
235
+ "3733": 232,
236
+ "4331": 233,
237
+ "258": 234,
238
+ "6308": 235,
239
+ "4363": 236,
240
+ "8776": 237,
241
+ "5635": 238,
242
+ "2074": 239,
243
+ "3105": 240,
244
+ "5007": 241,
245
+ "2592": 242,
246
+ "2774": 243,
247
+ "764": 244,
248
+ "7525": 245,
249
+ "2598": 246,
250
+ "5740": 247,
251
+ "7868": 248,
252
+ "8028": 249,
253
+ "3816": 250,
254
+ "3118": 251,
255
+ "8118": 252,
256
+ "3009": 253,
257
+ "7956": 254,
258
+ "5062": 255,
259
+ "3728": 256,
260
+ "1649": 257,
261
+ "2517": 258,
262
+ "56": 259,
263
+ "8050": 260,
264
+ "1678": 261,
265
+ "6300": 262,
266
+ "274": 263,
267
+ "7881": 264,
268
+ "3927": 265,
269
+ "6643": 266,
270
+ "1289": 267,
271
+ "8006": 268,
272
+ "175": 269,
273
+ "1777": 270,
274
+ "781": 271,
275
+ "6788": 272,
276
+ "699": 273,
277
+ "5583": 274,
278
+ "4138": 275,
279
+ "511": 276,
280
+ "6388": 277,
281
+ "2427": 278,
282
+ "8713": 279,
283
+ "7434": 280,
284
+ "288": 281,
285
+ "5293": 282,
286
+ "1046": 283,
287
+ "1482": 284,
288
+ "8825": 285,
289
+ "510": 286,
290
+ "8195": 287,
291
+ "5984": 288,
292
+ "409": 289,
293
+ "8643": 290,
294
+ "4744": 291,
295
+ "6371": 292,
296
+ "8591": 293,
297
+ "4434": 294,
298
+ "54": 295,
299
+ "126": 296,
300
+ "7258": 297,
301
+ "4289": 298,
302
+ "373": 299,
303
+ "7569": 300,
304
+ "2238": 301,
305
+ "5029": 302,
306
+ "8887": 303,
307
+ "7498": 304,
308
+ "7705": 305,
309
+ "7816": 306,
310
+ "7828": 307,
311
+ "2240": 308,
312
+ "1472": 309,
313
+ "2137": 310,
314
+ "7437": 311,
315
+ "4592": 312,
316
+ "718": 313,
317
+ "3972": 314,
318
+ "6032": 315,
319
+ "7867": 316,
320
+ "2775": 317,
321
+ "2638": 318,
322
+ "5239": 319,
323
+ "6696": 320,
324
+ "6763": 321,
325
+ "1903": 322,
326
+ "6098": 323,
327
+ "6937": 324,
328
+ "3230": 325,
329
+ "2589": 326,
330
+ "4967": 327,
331
+ "1283": 328,
332
+ "8824": 329,
333
+ "8163": 330,
334
+ "1335": 331,
335
+ "1165": 332,
336
+ "8534": 333,
337
+ "1789": 334,
338
+ "6538": 335,
339
+ "6395": 336,
340
+ "5448": 337,
341
+ "597": 338,
342
+ "3989": 339,
343
+ "6373": 340,
344
+ "6519": 341,
345
+ "6637": 342,
346
+ "6406": 343,
347
+ "5190": 344,
348
+ "2787": 345,
349
+ "3945": 346,
350
+ "8758": 347,
351
+ "359": 348,
352
+ "6269": 349,
353
+ "5684": 350,
354
+ "6981": 351,
355
+ "2582": 352,
356
+ "1093": 353,
357
+ "1826": 354,
358
+ "6317": 355,
359
+ "7000": 356,
360
+ "4860": 357,
361
+ "8388": 358,
362
+ "5724": 359,
363
+ "1121": 360,
364
+ "8404": 361,
365
+ "6575": 362,
366
+ "5975": 363,
367
+ "2494": 364,
368
+ "561": 365,
369
+ "6294": 366,
370
+ "7294": 367,
371
+ "6918": 368,
372
+ "5519": 369,
373
+ "2570": 370,
374
+ "5206": 371,
375
+ "2010": 372,
376
+ "8008": 373,
377
+ "2512": 374,
378
+ "2882": 375,
379
+ "3977": 376,
380
+ "6054": 377,
381
+ "1018": 378,
382
+ "1079": 379,
383
+ "3221": 380,
384
+ "3083": 381,
385
+ "2532": 382,
386
+ "2741": 383,
387
+ "3357": 384,
388
+ "1028": 385,
389
+ "1401": 386,
390
+ "806": 387,
391
+ "1054": 388,
392
+ "6188": 389,
393
+ "5622": 390,
394
+ "101": 391,
395
+ "6014": 392,
396
+ "4222": 393,
397
+ "5656": 394,
398
+ "5092": 395,
399
+ "303": 396,
400
+ "4854": 397,
401
+ "4519": 398,
402
+ "716": 399,
403
+ "3289": 400,
404
+ "2060": 401,
405
+ "850": 402,
406
+ "8138": 403,
407
+ "5389": 404,
408
+ "7140": 405,
409
+ "369": 406,
410
+ "667": 407,
411
+ "188": 408,
412
+ "1112": 409,
413
+ "1509": 410,
414
+ "2758": 411,
415
+ "4837": 412,
416
+ "2230": 413,
417
+ "2388": 414,
418
+ "205": 415,
419
+ "984": 416,
420
+ "1535": 417,
421
+ "500": 418,
422
+ "4856": 419,
423
+ "6510": 420,
424
+ "7318": 421,
425
+ "5809": 422,
426
+ "3224": 423,
427
+ "835": 424,
428
+ "7117": 425,
429
+ "337": 426,
430
+ "5802": 427,
431
+ "1348": 428,
432
+ "480": 429,
433
+ "5189": 430,
434
+ "3864": 431,
435
+ "2999": 432,
436
+ "4257": 433,
437
+ "7967": 434,
438
+ "5918": 435,
439
+ "8855": 436,
440
+ "598": 437,
441
+ "5333": 438,
442
+ "2364": 439,
443
+ "6505": 440,
444
+ "1851": 441,
445
+ "4848": 442,
446
+ "882": 443,
447
+ "580": 444,
448
+ "2269": 445,
449
+ "3967": 446,
450
+ "8573": 447,
451
+ "9023": 448,
452
+ "512": 449,
453
+ "1349": 450,
454
+ "4629": 451,
455
+ "1382": 452,
456
+ "6206": 453,
457
+ "7783": 454,
458
+ "3448": 455,
459
+ "6378": 456,
460
+ "3905": 457,
461
+ "6167": 458,
462
+ "7384": 459,
463
+ "666": 460,
464
+ "949": 461,
465
+ "6924": 462,
466
+ "4290": 463,
467
+ "4490": 464,
468
+ "8401": 465,
469
+ "7383": 466,
470
+ "4598": 467,
471
+ "6446": 468,
472
+ "329": 469,
473
+ "7809": 470,
474
+ "231": 471,
475
+ "30": 472,
476
+ "6215": 473,
477
+ "6686": 474,
478
+ "3654": 475,
479
+ "7939": 476,
480
+ "5810": 477,
481
+ "6574": 478,
482
+ "14": 479,
483
+ "7478": 480,
484
+ "8494": 481,
485
+ "7732": 482,
486
+ "7030": 483,
487
+ "8635": 484,
488
+ "296": 485,
489
+ "5968": 486,
490
+ "2012": 487,
491
+ "5401": 488,
492
+ "7316": 489,
493
+ "3258": 490,
494
+ "4839": 491,
495
+ "707": 492,
496
+ "1874": 493,
497
+ "3521": 494,
498
+ "7932": 495,
499
+ "6865": 496,
500
+ "7285": 497,
501
+ "2992": 498,
502
+ "8421": 499,
503
+ "4425": 500,
504
+ "3008": 501,
505
+ "6115": 502,
506
+ "7688": 503,
507
+ "968": 504,
508
+ "1724": 505,
509
+ "8266": 506,
510
+ "2194": 507,
511
+ "7520": 508,
512
+ "5604": 509,
513
+ "1571": 510,
514
+ "6160": 511,
515
+ "2294": 512,
516
+ "7647": 513,
517
+ "4064": 514,
518
+ "3171": 515,
519
+ "548": 516,
520
+ "2673": 517,
521
+ "8119": 518,
522
+ "6492": 519,
523
+ "8396": 520,
524
+ "3446": 521,
525
+ "7335": 522,
526
+ "7169": 523,
527
+ "224": 524,
528
+ "8097": 525,
529
+ "3852": 526,
530
+ "920": 527,
531
+ "2368": 528,
532
+ "664": 529,
533
+ "4800": 530,
534
+ "7241": 531,
535
+ "278": 532,
536
+ "1769": 533,
537
+ "4731": 534,
538
+ "953": 535,
539
+ "3825": 536,
540
+ "8742": 537,
541
+ "2688": 538,
542
+ "1603": 539,
543
+ "3645": 540,
544
+ "5133": 541,
545
+ "4278": 542,
546
+ "1513": 543,
547
+ "7938": 544,
548
+ "4054": 545,
549
+ "272": 546,
550
+ "7802": 547,
551
+ "1859": 548,
552
+ "7069": 549,
553
+ "8687": 550,
554
+ "4719": 551,
555
+ "2751": 552,
556
+ "7085": 553,
557
+ "5266": 554,
558
+ "1050": 555,
559
+ "6060": 556,
560
+ "5386": 557,
561
+ "6782": 558,
562
+ "5868": 559,
563
+ "353": 560,
564
+ "227": 561,
565
+ "204": 562,
566
+ "479": 563,
567
+ "5002": 564,
568
+ "5154": 565,
569
+ "1025": 566,
570
+ "7051": 567,
571
+ "7095": 568,
572
+ "6235": 569,
573
+ "3294": 570,
574
+ "6993": 571,
575
+ "3835": 572,
576
+ "2393": 573,
577
+ "1779": 574,
578
+ "948": 575,
579
+ "8464": 576,
580
+ "6359": 577,
581
+ "1343": 578,
582
+ "8791": 579,
583
+ "8176": 580,
584
+ "1748": 581,
585
+ "1825": 582,
586
+ "249": 583,
587
+ "5337": 584,
588
+ "7909": 585,
589
+ "8677": 586,
590
+ "22": 587,
591
+ "3082": 588,
592
+ "119": 589,
593
+ "6727": 590,
594
+ "3513": 591,
595
+ "7991": 592,
596
+ "4734": 593,
597
+ "5606": 594,
598
+ "4260": 595,
599
+ "4226": 596,
600
+ "7398": 597,
601
+ "3551": 598,
602
+ "4806": 599,
603
+ "868": 600,
604
+ "5883": 601,
605
+ "7011": 602,
606
+ "1053": 603,
607
+ "112": 604,
608
+ "1379": 605,
609
+ "166": 606,
610
+ "1212": 607,
611
+ "594": 608,
612
+ "589": 609,
613
+ "454": 610,
614
+ "925": 611,
615
+ "497": 612,
616
+ "1322": 613,
617
+ "7777": 614,
618
+ "70": 615,
619
+ "7297": 616,
620
+ "159": 617,
621
+ "899": 618,
622
+ "7061": 619,
623
+ "2618": 620,
624
+ "7555": 621,
625
+ "8459": 622,
626
+ "1365": 623,
627
+ "8879": 624,
628
+ "2113": 625,
629
+ "1066": 626,
630
+ "1445": 627,
631
+ "2229": 628,
632
+ "8057": 629,
633
+ "345": 630,
634
+ "612": 631,
635
+ "1958": 632,
636
+ "8545": 633,
637
+ "1448": 634,
638
+ "3328": 635,
639
+ "2929": 636,
640
+ "7962": 637,
641
+ "1387": 638,
642
+ "6494": 639,
643
+ "3876": 640,
644
+ "380": 641,
645
+ "5723": 642,
646
+ "6553": 643,
647
+ "3114": 644,
648
+ "398": 645,
649
+ "1885": 646,
650
+ "2709": 647,
651
+ "2920": 648,
652
+ "576": 649,
653
+ "1160": 650,
654
+ "246": 651,
655
+ "4335": 652,
656
+ "7558": 653,
657
+ "5115": 654,
658
+ "4133": 655,
659
+ "1641": 656,
660
+ "6956": 657,
661
+ "5139": 658,
662
+ "1264": 659,
663
+ "1259": 660,
664
+ "1705": 661,
665
+ "1961": 662,
666
+ "1012": 663,
667
+ "4057": 664,
668
+ "6895": 665,
669
+ "8347": 666,
670
+ "1337": 667,
671
+ "7395": 668,
672
+ "5672": 669,
673
+ "8479": 670,
674
+ "8190": 671,
675
+ "4381": 672,
676
+ "8771": 673,
677
+ "1060": 674,
678
+ "6258": 675,
679
+ "2654": 676,
680
+ "5123": 677,
681
+ "5304": 678,
682
+ "2790": 679,
683
+ "2254": 680,
684
+ "3493": 681,
685
+ "3180": 682,
686
+ "6352": 683,
687
+ "1987": 684,
688
+ "3094": 685,
689
+ "9026": 686,
690
+ "1222": 687,
691
+ "4013": 688,
692
+ "340": 689,
693
+ "2785": 690,
694
+ "3072": 691,
695
+ "6339": 692,
696
+ "836": 693,
697
+ "3330": 694,
698
+ "475": 695,
699
+ "7945": 696,
700
+ "6082": 697,
701
+ "4111": 698,
702
+ "8705": 699,
703
+ "3781": 700,
704
+ "4010": 701,
705
+ "5746": 702,
706
+ "1827": 703,
707
+ "476": 704,
708
+ "7538": 705,
709
+ "8875": 706,
710
+ "2149": 707,
711
+ "2272": 708,
712
+ "6509": 709,
713
+ "2127": 710,
714
+ "7730": 711,
715
+ "2823": 712,
716
+ "3490": 713,
717
+ "1241": 714,
718
+ "698": 715,
719
+ "1460": 716,
720
+ "6620": 717,
721
+ "1845": 718,
722
+ "3889": 719,
723
+ "1383": 720,
724
+ "3630": 721,
725
+ "2473": 722,
726
+ "2319": 723,
727
+ "6288": 724,
728
+ "7837": 725,
729
+ "711": 726,
730
+ "5186": 727,
731
+ "4238": 728,
732
+ "1944": 729,
733
+ "3379": 730,
734
+ "2577": 731,
735
+ "451": 732,
736
+ "3389": 733,
737
+ "5242": 734,
738
+ "1849": 735,
739
+ "2110": 736,
740
+ "5147": 737,
741
+ "4807": 738,
742
+ "7957": 739,
743
+ "581": 740,
744
+ "4108": 741,
745
+ "1323": 742,
746
+ "1556": 743,
747
+ "8619": 744,
748
+ "783": 745,
749
+ "957": 746,
750
+ "3979": 747,
751
+ "1052": 748,
752
+ "439": 749,
753
+ "115": 750,
754
+ "3866": 751,
755
+ "6673": 752,
756
+ "1943": 753,
757
+ "6037": 754,
758
+ "6157": 755,
759
+ "1668": 756,
760
+ "240": 757,
761
+ "596": 758,
762
+ "1974": 759,
763
+ "8592": 760,
764
+ "1061": 761,
765
+ "81": 762,
766
+ "639": 763,
767
+ "1487": 764,
768
+ "954": 765,
769
+ "8300": 766,
770
+ "1425": 767,
771
+ "2162": 768,
772
+ "7090": 769,
773
+ "4733": 770,
774
+ "492": 771,
775
+ "2061": 772,
776
+ "7733": 773,
777
+ "1413": 774,
778
+ "208": 775,
779
+ "5727": 776,
780
+ "2146": 777,
781
+ "6518": 778,
782
+ "4590": 779,
783
+ "216": 780,
784
+ "2652": 781,
785
+ "5776": 782,
786
+ "6330": 783,
787
+ "5588": 784,
788
+ "2285": 785,
789
+ "7126": 786,
790
+ "5513": 787,
791
+ "979": 788,
792
+ "5940": 789,
793
+ "1740": 790,
794
+ "4044": 791,
795
+ "8228": 792,
796
+ "543": 793,
797
+ "17": 794,
798
+ "4358": 795,
799
+ "5157": 796,
800
+ "8575": 797,
801
+ "1417": 798,
802
+ "1224": 799,
803
+ "606": 800,
804
+ "2628": 801,
805
+ "3546": 802,
806
+ "5063": 803,
807
+ "2201": 804,
808
+ "1311": 805,
809
+ "593": 806,
810
+ "1547": 807,
811
+ "1629": 808,
812
+ "4152": 809,
813
+ "1914": 810,
814
+ "7445": 811,
815
+ "4246": 812,
816
+ "1607": 813,
817
+ "559": 814,
818
+ "1801": 815,
819
+ "122": 816,
820
+ "1639": 817,
821
+ "663": 818,
822
+ "323": 819,
823
+ "2411": 820,
824
+ "1182": 821,
825
+ "3025": 822,
826
+ "5767": 823,
827
+ "2397": 824,
828
+ "2093": 825,
829
+ "4495": 826,
830
+ "708": 827,
831
+ "986": 828,
832
+ "154": 829,
833
+ "1498": 830,
834
+ "829": 831,
835
+ "3792": 832,
836
+ "549": 833,
837
+ "815": 834,
838
+ "362": 835,
839
+ "1933": 836,
840
+ "3790": 837,
841
+ "1100": 838,
842
+ "1811": 839,
843
+ "2812": 840,
844
+ "55": 841,
845
+ "636": 842,
846
+ "8490": 843,
847
+ "2499": 844,
848
+ "2167": 845,
849
+ "803": 846,
850
+ "3537": 847,
851
+ "4236": 848,
852
+ "90": 849,
853
+ "217": 850,
854
+ "3340": 851,
855
+ "7475": 852,
856
+ "1336": 853,
857
+ "2581": 854,
858
+ "1316": 855,
859
+ "7665": 856,
860
+ "1001": 857,
861
+ "3001": 858,
862
+ "3738": 859,
863
+ "472": 860,
864
+ "1473": 861,
865
+ "3215": 862,
866
+ "487": 863,
867
+ "1290": 864,
868
+ "7339": 865,
869
+ "3686": 866,
870
+ "7342": 867,
871
+ "6075": 868,
872
+ "210": 869,
873
+ "1552": 870,
874
+ "38": 871,
875
+ "7657": 872,
876
+ "2816": 873,
877
+ "1734": 874,
878
+ "6690": 875,
879
+ "1390": 876,
880
+ "3070": 877,
881
+ "8722": 878,
882
+ "7518": 879,
883
+ "1754": 880,
884
+ "408": 881,
885
+ "1645": 882,
886
+ "2696": 883,
887
+ "830": 884,
888
+ "4846": 885,
889
+ "1175": 886,
890
+ "2498": 887,
891
+ "434": 888,
892
+ "1058": 889,
893
+ "2674": 890,
894
+ "583": 891,
895
+ "1296": 892,
896
+ "4681": 893,
897
+ "4926": 894,
898
+ "820": 895,
899
+ "6119": 896,
900
+ "1536": 897,
901
+ "3187": 898,
902
+ "157": 899,
903
+ "2039": 900,
904
+ "4433": 901,
905
+ "4071": 902,
906
+ "2085": 903
907
+ },
908
+ "phone_to_idx": {
909
+ "AA0": 0,
910
+ "AA1": 1,
911
+ "AA2": 2,
912
+ "AE0": 3,
913
+ "AE1": 4,
914
+ "AE2": 5,
915
+ "AH0": 6,
916
+ "AH1": 7,
917
+ "AH2": 8,
918
+ "AO0": 9,
919
+ "AO1": 10,
920
+ "AO2": 11,
921
+ "AW0": 12,
922
+ "AW1": 13,
923
+ "AW2": 14,
924
+ "AY0": 15,
925
+ "AY1": 16,
926
+ "AY2": 17,
927
+ "B": 18,
928
+ "CH": 19,
929
+ "D": 20,
930
+ "DH": 21,
931
+ "EH0": 22,
932
+ "EH1": 23,
933
+ "EH2": 24,
934
+ "EMPTY": 25,
935
+ "ER0": 26,
936
+ "ER1": 27,
937
+ "ER2": 28,
938
+ "EY0": 29,
939
+ "EY1": 30,
940
+ "EY2": 31,
941
+ "F": 32,
942
+ "G": 33,
943
+ "HH": 34,
944
+ "IH0": 35,
945
+ "IH1": 36,
946
+ "IH2": 37,
947
+ "IY0": 38,
948
+ "IY1": 39,
949
+ "IY2": 40,
950
+ "JH": 41,
951
+ "K": 42,
952
+ "L": 43,
953
+ "M": 44,
954
+ "N": 45,
955
+ "NG": 46,
956
+ "OW0": 47,
957
+ "OW1": 48,
958
+ "OW2": 49,
959
+ "OY0": 50,
960
+ "OY1": 51,
961
+ "OY2": 52,
962
+ "P": 53,
963
+ "R": 54,
964
+ "S": 55,
965
+ "SH": 56,
966
+ "SPN": 57,
967
+ "T": 58,
968
+ "TH": 59,
969
+ "UH0": 60,
970
+ "UH1": 61,
971
+ "UH2": 62,
972
+ "UW0": 63,
973
+ "UW1": 64,
974
+ "UW2": 65,
975
+ "V": 66,
976
+ "W": 67,
977
+ "Y": 68,
978
+ "Z": 69,
979
+ "ZH": 70,
980
+ "PAD": 71
981
+ },
982
+ "phone_kind_to_idx": {
983
+ "EMPTY": 0,
984
+ "WORD": 1,
985
+ "START": 2,
986
+ "END": 3,
987
+ "MIDDLE": 4
988
+ }
989
+ }
models.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+
12
+ import math
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch.nn.attention.flex_attention import BlockMask, flex_attention
19
+
20
+
21
class Mlp(nn.Module):
    """Two-layer feed-forward block (Vision-Transformer-style MLP).

    Pipeline: fc1 -> activation -> optional norm -> fc2.  Hidden and
    output widths default to the input width when not provided.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.SiLU,
        norm_layer=None,
        bias=True,
    ):
        super().__init__()
        # Fall back to the input width for unspecified sizes.
        hidden = hidden_features or in_features
        out = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden, bias=bias)
        self.act = act_layer()
        # Identity when no norm is requested, so forward() stays uniform.
        self.norm = nn.Identity() if norm_layer is None else norm_layer(hidden)
        self.fc2 = nn.Linear(hidden, out, bias=bias)

    def forward(self, x):
        return self.fc2(self.norm(self.act(self.fc1(x))))
51
+
52
+
53
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, no bias).

    The normalization itself runs in float32 for numerical stability and is
    cast back to the input dtype before the learned gain is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Per-channel learnable gain, initialized to identity.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), expressed with rsqrt.
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight
65
+
66
+
67
class Attention(nn.Module):
    """Multi-head self-attention with per-head RMS-normalized queries/keys.

    forward() has two execution paths selected by the mask type:
      * dense tensor mask (or None) -> F.scaled_dot_product_attention
      * flex-attention BlockMask    -> flex_attention, run in fp16
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # NOTE(review): self.scale is never used in forward(); both kernels
        # below apply their own default 1/sqrt(head_dim) scaling.
        self.scale = self.head_dim**-0.5

        # Fused projection producing q, k and v in one matmul.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.q_norm = RMSNorm(self.head_dim, eps=1e-5)
        self.k_norm = RMSNorm(self.head_dim, eps=1e-5)

    def forward(self, x: torch.Tensor, attn_mask=None) -> torch.Tensor:
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)

        if isinstance(attn_mask, torch.Tensor) or attn_mask is None:
            q = self.q_norm(q)
            k = self.k_norm(k)
            # v = v
            x = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_mask,
            )
        elif isinstance(attn_mask, BlockMask):
            # flex_attention path: disable autocast and run explicitly in
            # fp16 (q/k normalized first, then cast, as above).
            with torch.autocast(enabled=False, device_type="cuda"):
                q = self.q_norm(q).half()
                k = self.k_norm(k).half()
                v = v.half()
                x = flex_attention(q, k, v, block_mask=attn_mask)
        # NOTE(review): an attn_mask of any other type leaves x unassigned
        # to an attention result — assumed callers only pass the two
        # supported mask kinds.
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x
113
+
114
+
115
def modulate(x, shift, scale):
    """AdaLN modulation: gate x around identity by *scale*, then add *shift*."""
    gained = x * (scale + 1)
    return gained + shift
117
+
118
+
119
+ #################################################################################
120
+ # Embedding Layers for Timesteps and Class Labels #
121
+ #################################################################################
122
+
123
+
124
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar diffusion timesteps into vector representations.

    A fixed sinusoidal encoding of the timestep is projected through a
    small SiLU MLP to the model hidden size.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        :param t: a 1-D Tensor of N indices, one per batch element
            (may be fractional).
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, dim) Tensor of positional embeddings.
        """
        # Geometric frequency ladder, as in GLIDE's nn.py.
        half = dim // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
        freqs = torch.exp(-math.log(max_period) * exponents).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd dim: pad a zero column so the width is exactly `dim`.
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        freq_emb = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(freq_emb)
167
+
168
+
169
class LabelEmbedder(nn.Module):
    """
    Embeds the conditioning labels (speaker id, phone, phone kind) into
    vectors, with label dropout for classifier-free guidance.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # The last class index is reserved as the "unconditional" token.
        # NOTE(review): only defined when dropout_prob > 0, so token_drop()
        # via force_drop_ids would raise AttributeError otherwise — preserved
        # from the original; confirm callers never do that.
        if dropout_prob > 0:
            self.unconditional_value = num_classes - 1
        self.speaker_id_table = nn.Embedding(num_classes, hidden_size)
        self.phone_table = nn.Embedding(num_classes, hidden_size)
        self.phone_kind_table = nn.Embedding(num_classes, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, speaker_id, phone, phone_kind, force_drop_ids=None):
        """
        Replace the labels of dropped rows with the unconditional value,
        enabling classifier-free guidance.
        """
        if force_drop_ids is not None:
            drop_ids = force_drop_ids == 1
        else:
            rand = torch.rand(speaker_id.shape[0], device=speaker_id.device)
            drop_ids = rand < self.dropout_prob
        # Broadcast the per-row drop decision across the label dimension.
        mask = drop_ids[:, None]
        speaker_id = torch.where(mask, self.unconditional_value, speaker_id)
        phone = torch.where(mask, self.unconditional_value, phone)
        phone_kind = torch.where(mask, self.unconditional_value, phone_kind)
        return speaker_id, phone, phone_kind

    def forward(self, speaker_id, phone, phone_kind, train, force_drop_ids=None):
        # Dropout applies during training (if enabled) or when drops are forced.
        should_drop = (train and self.dropout_prob > 0) or (
            force_drop_ids is not None
        )
        if should_drop:
            speaker_id, phone, phone_kind = self.token_drop(
                speaker_id, phone, phone_kind, force_drop_ids
            )
        return (
            self.speaker_id_table(speaker_id),
            self.phone_table(phone),
            self.phone_kind_table(phone_kind),
        )
216
+
217
+
218
+ #################################################################################
219
+ # Core DiT Model #
220
+ #################################################################################
221
+
222
+
223
class DiTBlock(nn.Module):
    """
    A DiT transformer block with adaptive layer norm zero (adaLN-Zero)
    conditioning: the conditioning vector yields shift/scale/gate terms
    for both the attention and the MLP branch.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(
            hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
        )
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=nn.SiLU,
        )
        # Projects the conditioning vector to the six modulation terms.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c, attn_mask=None):
        mods = self.adaLN_modulation(c).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods
        attn_out = self.attn(
            modulate(self.norm1(x), shift_msa, scale_msa), attn_mask=attn_mask
        )
        x = x + gate_msa * attn_out
        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x + gate_mlp * mlp_out
255
+
256
+
257
class FinalLayer(nn.Module):
    """
    The final DiT layer: adaLN-modulated LayerNorm followed by a linear
    projection to patch_size**2 * out_channels.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        # Conditioning produces only shift and scale here (no gate term).
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        h = modulate(self.norm_final(x), shift, scale)
        return self.linear(h)
278
+
279
+
280
class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Unlike the original image DiT, this variant operates on 1-D sequences:
    inputs are (N, C, L) feature sequences projected per position into the
    hidden size. Conditioning combines the diffusion timestep embedding with
    per-position speaker/phone/phone-kind embeddings from ``LabelEmbedder``.
    """

    def __init__(
        self,
        input_size=256,
        in_channels=1024,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        learn_sigma=True,
        embedding_vocab_size=1024,
    ):
        # input_size: maximum sequence length covered by the positional embedding.
        # in_channels: dimensionality of each input frame.
        # learn_sigma: when True the model also predicts a variance channel,
        #   doubling the output channel count.
        super().__init__()
        self.input_size = input_size
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.hidden_size = hidden_size
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.num_heads = num_heads

        self.x_embedder = nn.Linear(in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(
            embedding_vocab_size, hidden_size, class_dropout_prob
        )
        # Will use fixed sin-cos embedding; filled in by initialize_weights():
        self.register_buffer("pos_embed", torch.zeros(1, self.input_size, hidden_size))

        self.blocks = nn.ModuleList(
            [
                DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
                for _ in range(depth)
            ]
        )
        # patch_size=1: one output vector per sequence position.
        self.final_layer = FinalLayer(hidden_size, 1, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Apply the DiT initialization scheme (Xavier + adaLN-zero)."""

        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_1d_sincos_pos_embed(self.pos_embed.shape[-1], self.input_size)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.bias, 0)

        # Initialize label embedding tables with a 1/sqrt(D) scale:
        scale = 1.0 / math.sqrt(self.hidden_size)
        nn.init.trunc_normal_(self.y_embedder.speaker_id_table.weight, std=scale)
        nn.init.trunc_normal_(self.y_embedder.phone_table.weight, std=scale)

        # Initialize timestep embedding MLP:
        nn.init.trunc_normal_(self.t_embedder.mlp[0].weight, std=scale)
        nn.init.trunc_normal_(self.t_embedder.mlp[2].weight, std=scale)

        # Zero-out adaLN modulation layers in DiT blocks, so every block
        # starts as the identity function (adaLN-Zero):
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, speaker_id, phone, phone_kind, attn_mask=None):
        """
        Forward pass of DiT.
        x: (N, C, L) tensor of spatial inputs
        t: (N,) tensor of diffusion timesteps
        speaker_id: (N,) tensor of speaker IDs
        phone: (N, L) tensor of phone labels
        phone_kind: (N, L) tensor of phone kinds
        attn_mask: optional attention mask forwarded to every block
        Returns an (N, out_channels, L) tensor.
        """
        # (N, D), (N, L, D)
        speaker_id_embedding, phone_embedding, phone_kind_embedding = self.y_embedder(
            speaker_id, phone, phone_kind, self.training
        )
        t = self.t_embedder(t)  # (N, D)
        c = t  # (N, D)
        # Broadcast the timestep embedding across positions and add the
        # per-position label embeddings -> per-position conditioning.
        c = (
            c[:, None, :]
            + speaker_id_embedding
            + phone_embedding
            + phone_kind_embedding
        )  # (N, L, D)

        x = x.transpose(-1, -2)  # (N, C, L) -> (N, L, C)
        x = self.x_embedder(x) + self.pos_embed[:, : x.shape[1], :]  # (N, L, D)
        for block in self.blocks:
            x = block(x, c, attn_mask=attn_mask)  # (N, L, D)
        x = self.final_layer(x, c)  # (N, L, out_channels)
        x = x.transpose(-1, -2)  # (N, L, out_channels) -> (N, out_channels, L)
        return x

    def forward_with_cfg(
        self, x, t, speaker_id, phone, phone_kind, cfg_scale, attn_mask=None
    ):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.

        Expects the batch to be [conditional_half, unconditional_half]: the
        first half of x is duplicated, and guidance is applied to the eps
        (first in_channels) part of the output only; the remaining (sigma)
        channels are passed through unchanged.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(
            combined, t, speaker_id, phone, phone_kind, attn_mask=attn_mask
        )
        # Split the prediction into eps (guided) and the rest (e.g. learned sigma,
        # left unguided). Guidance: uncond + scale * (cond - uncond).
        eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
412
+
413
+
414
+ #################################################################################
415
+ # Sine/Cosine Positional Embedding Functions #
416
+ #################################################################################
417
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
418
+
419
+
420
def get_1d_sincos_pos_embed(embed_dim, length, cls_token=False, extra_tokens=0):
    """
    Build a fixed 1-D sine/cosine positional embedding table.

    length: number of positions to embed
    cls_token / extra_tokens: when a cls token is used, prepend
        `extra_tokens` rows of zeros for the extra token positions
    return:
        pos_embed: [length, embed_dim] or [extra_tokens+length, embed_dim]
    """
    positions = np.arange(length, dtype=np.float32)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    if not (cls_token and extra_tokens > 0):
        return table
    # Reserve zero rows at the front for the extra (cls) tokens.
    padding = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([padding, table], axis=0)
433
+
434
+
435
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position (must be even)
    pos: array of positions to be encoded, size (M,)
    out: (M, embed_dim) — first half sine features, second half cosine
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # Geometric progression of inverse frequencies: 1 / 10000^(i / half_dim).
    inv_freq = 1.0 / 10000 ** (np.arange(half_dim, dtype=np.float64) / half_dim)

    # Outer product position x frequency -> (M, half_dim) angle matrix.
    angles = np.outer(pos.reshape(-1), inv_freq)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
454
+
455
+
456
+ #################################################################################
457
+ #################################################################################
458
+ # DiT Configs #
459
+ #################################################################################
460
+
461
+
462
def DiT_XL(**kwargs):
    # Extra-large preset: 28 layers, width 1152, 16 heads.
    return DiT(depth=28, hidden_size=1152, num_heads=16, **kwargs)


def DiT_L(**kwargs):
    # Large preset: 24 layers, width 1024, 16 heads.
    return DiT(depth=24, hidden_size=1024, num_heads=16, **kwargs)


def DiT_B(**kwargs):
    # Base preset: 12 layers, width 768, 12 heads.
    return DiT(depth=12, hidden_size=768, num_heads=12, **kwargs)


def DiT_S(**kwargs):
    # Small preset: 6 layers, width 256, 4 heads.
    return DiT(depth=6, hidden_size=256, num_heads=4, **kwargs)


# Registry mapping config names (as used in YAML configs) to factory functions.
DiT_models = {"DiT-XL": DiT_XL, "DiT-L": DiT_L, "DiT-B": DiT_B, "DiT-S": DiT_S}
nltk_data/taggers/averaged_perceptron_tagger_eng/averaged_perceptron_tagger_eng.classes.json ADDED
@@ -0,0 +1 @@
 
 
1
+ [".", "(", ")", ":", "''", "EX", "JJS", "WRB", "VBG", "VBP", "NN", "SYM", "VB", "UH", "NNPS", "NNP", "``", "$", "NNS", "JJR", "MD", "RP", "VBD", "DT", "POS", "RBR", ",", "VBZ", "PDT", "VBN", "WP$", "WDT", "WP", "PRP$", "CD", "IN", "#", "CC", "RB", "FW", "RBS", "PRP", "LS", "JJ", "TO"]
nltk_data/taggers/averaged_perceptron_tagger_eng/averaged_perceptron_tagger_eng.tagdict.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"four": "CD", "facilities": "NNS", "controversial": "JJ", "Until": "IN", "whose": "WP$", "under": "IN", "pact": "NN", "regional": "JJ", "GE": "NNP", "every": "DT", "GM": "NNP", "Moon": "NNP", "school": "NN", "companies": "NNS", "Nasdaq": "NNP", "Paul": "NNP", "Pinkerton": "NNP", "leaders": "NNS", "guidelines": "NNS", "Sales": "NNS", "machines": "NNS", "pace": "NN", "spokesman": "NN", "new": "JJ", "ever": "RB", "men": "NNS", "here": "RB", "protection": "NN", "studio": "NN", "active": "JJ", "100": "CD", "ventures": "NNS", "items": "NNS", "employees": "NNS", "credit": "NN", "analysts": "NNS", "criticism": "NN", "golden": "JJ", "Group": "NNP", "campaign": "NN", "St.": "NNP", "replace": "VB", "Also": "RB", "Health": "NNP", "costly": "JJ", "unit": "NN", "swings": "NNS", "would": "MD", "century": "NN", "June": "NNP", "music": "NN", "asset": "NN", "N.J.": "NNP", "until": "IN", "PaineWebber": "NNP", "Breeden": "NNP", "Ministry": "NNP", "successful": "JJ", "phone": "NN", "90": "CD", "circumstances": "NNS", "me": "PRP", "1990": "CD", "1993": "CD", "1992": "CD", "word": "NN", "1994": "CD", "rights": "NNS", "movies": "NNS", "already": "RB", "my": "PRP$", "example": "NN", "estate": "NN", "psyllium": "NN", "Hurricane": "NNP", "10,000": "CD", "Digital": "NNP", "totaled": "VBD", "recovery": "NN", "Journal": "NNP", "thousands": "NNS", "machine": "NN", "how": "WRB", "Jack": "NNP", "interview": "NN", "resignation": "NN", "minority": "NN", "L.": "NNP", "after": "IN", "modest": "JJ", "president": "NN", "law": "NN", "effective": "JJ", "Maxwell": "NNP", "Telerate": "NNP", "Another": "DT", "Trust": "NNP", "order": "NN", "operations": "NNS", "office": "NN", "expects": "VBZ", "presence": "NN", "His": "PRP$", "personal": "JJ", "expectations": "NNS", "Here": "RB", "production": "NN", "400": "CD", "Judge": "NNP", "weeks": "NNS", "Spain": "NNP", "eventually": "RB", "them": "PRP", "weakness": "NN", "Thomas": "NNP", "effects": "NNS", "they": "PRP", "schools": "NNS", "bank": "NN", "represents": 
"VBZ", "Indeed": "RB", "each": "DT", "went": "VBD", "bond": "NN", "financial": "JJ", "fairly": "RB", "series": "NN", "substantially": "RB", "lawyers": "NNS", "by": "IN", "network": "NN", "Chancellor": "NNP", "William": "NNP", "Icahn": "NNP", "size": "NN", "University": "NNP", "N.Y.": "NNP", "enormous": "JJ", "Monday": "NNP", "September": "NNP", "National": "NNP", "days": "NNS", "appeals": "NNS", "economists": "NNS", "another": "DT", "electronic": "JJ", "Congress": "NNP", "lawsuits": "NNS", "rates": "NNS", "too": "RB", "percentage": "NN", "ceiling": "NN", "took": "VBD", "budget": "NN", "acquisition": "NN", "fashion": "NN", "Chicago": "NNP", "Cray": "NNP", "talking": "VBG", "seed": "NN", "Instead": "RB", "dozen": "NN", "Then": "RB", "strength": "NN", "responsible": "JJ", "-": ":", "practices": "NNS", "Minister": "NNP", "They": "PRP", "Bank": "NNP", "unsecured": "JJ", "Jones": "NNP", "shall": "MD", "involving": "VBG", "letter": "NN", "Mobil": "NNP", "medical": "JJ", "competitors": "NNS", "consumer": "NN", "Its": "PRP$", "came": "VBD", "Union": "NNP", "meetings": "NNS", "ending": "VBG", "specialists": "NNS", "judges": "NNS", "Mixte": "NNP", "representing": "VBG", "exports": "NNS", "wide": "JJ", "13": "CD", "certificates": "NNS", "despite": "IN", "volatility": "NN", "countries": "NNS", "high-yield": "JJ", "Washington": "NNP", "bad": "JJ", "Qintex": "NNP", "movement": "NN", "secretary": "NN", "Gorbachev": "NNP", "discussions": "NNS", "John": "NNP", "said": "VBD", "capacity": "NN", "wage": "NN", "we": "PRP", "never": "RB", "terms": "NNS", "wo": "MD", "were": "VBD", "weak": "JJ", "however": "RB", "news": "NN", "debt": "NN", "Among": "IN", "country": "NN", "uncertainty": "NN", "against": "IN", "Thomson": "NNP", "players": "NNS", "Computer": "NNP", "games": "NNS", "faces": "VBZ", "tough": "JJ", "tons": "NNS", "Board": "NNP", "250": "CD", "line-item": "JJ", "conference": "NN", "C.": "NNP", "basis": "NN", "union": "NN", "three": "CD", "been": "VBN", "C$": "$", "commission": 
"NN", "beer": "NN", "interest": "NN", "life": "NN", "families": "NNS", "Conn.": "NNP", "Tokyo": "NNP", "drugs": "NNS", "Poland": "NNP", "Secretary": "NNP", "Co": "NNP", "publicly": "RB", "property": "NN", "Tuesday": "NNP", "seven": "CD", "On": "IN", "is": "VBZ", "it": "PRP", "expenses": "NNS", "player": "NN", "Bush": "NNP", "experts": "NNS", "in": "IN", "victims": "NNS", "if": "IN", "things": "NNS", "damages": "NNS", "big": "JJ", "President": "NNP", "several": "JJ", "independent": "JJ", "Institute": "NNP", "hand": "NN", "Angeles": "NNP", "Morris": "NNP", "ownership": "NN", "opportunity": "NN", "cycle": "NN", "RJR": "NNP", "programs": "NNS", "client": "NN", "the": "DT", "corporate": "JJ", "investments": "NNS", "agency": "NN", "just": "RB", "unemployment": "NN", "previous": "JJ", "adding": "VBG", "buyers": "NNS", "board": "NN", "Philip": "NNP", "has": "VBZ", "gave": "VBD", "Santa": "NNP", "James": "NNP", "possible": "JJ", "Chrysler": "NNP", "30": "CD", "highly": "RB", "55": "CD", "51": "CD", "50": "CD", "securities": "NNS", "offices": "NNS", "officer": "NN", "night": "NN", "security": "NN", "Pentagon": "NNP", "attorney": "NN", "old": "JJ", "people": "NNS", "Commission": "NNP", "election": "NN", "short-term": "JJ", "Lee": "NNP", "for": "IN", "comments": "NNS", "everything": "NN", "He": "PRP", "corn": "NN", "conventional": "JJ", "Georgia-Pacific": "NNP", "brokerage": "NN", "properties": "NNS", "dollars": "NNS", "months": "NNS", "magazine": "NN", "ensure": "VB", "afternoon": "NN", "efforts": "NNS", "Still": "RB", "slightly": "RB", "Fed": "NNP", "statements": "NNS", "facility": "NN", "civil": "JJ", "magazines": "NNS", "defendants": "NNS", "initial": "JJ", "legislation": "NN", "why": "WRB", "editor": "NN", "way": "NN", "NBC": "NNP", "was": "VBD", "war": "NN", "manufacturers": "NNS", "January": "NNP", "becoming": "VBG", "true": "JJ", "analyst": "NN", "counsel": "NN", "devices": "NNS", "County": "NNP", "Greenspan": "NNP", ".": ".", "Sir": "NNP", "evidence": "NN", "''": 
"''", "trip": "NN", "negotiations": "NNS", "LTV": "NNP", "Francisco": "NNP", "floor": "NN", "stake": "NN", "generally": "RB", "role": "NN", "models": "NNS", "Hunt": "NNP", "fell": "VBD", "authorities": "NNS", "'m": "VBP", "Mass.": "NNP", "weekend": "NN", "billion": "CD", "reorganization": "NN", "Estate": "NNP", "Charles": "NNP", "time": "NN", "serious": "JJ", "Moscow": "NNP", "profits": "NNS", "chain": "NN", "global": "JJ", "alternatives": "NNS", "manager": "NN", "battle": "NN", "certainly": "RB", "Sept.": "NNP", "Columbia": "NNP", "environment": "NN", "finally": "RB", "must": "MD", "1991": "CD", "choice": "NN", "liability": "NN", "trouble": "NN", "Jersey": "NNP", "room": "NN", "did": "VBD", "proposals": "NNS", "standards": "NNS", "speculation": "NN", "George": "NNP", "Rey": "NNP", "says": "VBZ", "trend": "NN", "M.": "NNP", "adds": "VBZ", "shares": "NNS", "Ford": "NNP", "current": "JJ", "goes": "VBZ", "international": "JJ", "falling": "VBG", "Nov.": "NNP", "transportation": "NN", "genes": "NNS", "water": "NN", "baseball": "NN", "groups": "NNS", "Ltd": "NNP", "appears": "VBZ", "Warner": "NNP", "healthy": "JJ", "guilty": "JJ", "trial": "NN", "usually": "RB", "Inc": "NNP", "studies": "NNS", "When": "WRB", "crisis": "NN", "market": "NN", "Australia": "NNP", "August": "NNP", "positive": "JJ", "sports": "NNS", "francs": "NNS", "today": "NN", "``": "``", "October": "NNP", "These": "DT", "downturn": "NN", "cases": "NNS", "effort": "NN", "currency": "NN", "car": "NN", "abortion": "NN", "Pacific": "NNP", "believes": "VBZ", "districts": "NNS", "can": "MD", "Our": "PRP$", "heart": "NN", "subsidies": "NNS", "1.2": "CD", "requirements": "NNS", "Akzo": "NNP", "1": "CD", "fourth": "JJ", "H.": "NNP", "Why": "WRB", "economy": "NN", "product": "NN", "information": "NN", "may": "MD", "membership": "NN", "date": "NN", "man": "NN", "natural": "JJ", "commodity": "NN", "futures": "NNS", "truck": "NN", "exclusive": "JJ", "indeed": "RB", "LIN": "NNP", "Hong": "NNP", "years": "NNS", "brain": 
"NN", "managers": "NNS", "White": "NNP", "still": "RB", "group": "NN", "Lehman": "NNP", "policy": "NN", "main": "JJ", "nation": "NN", "She": "PRP", "not": "RB", "R.": "NNP", "now": "RB", "provision": "NN", "nor": "CC", "term": "NN", "attorneys": "NNS", "Stanley": "NNP", "quarter": "NN", "significantly": "RB", "begun": "VBN", "year": "NN", "Kong": "NNP", "shown": "VBN", "space": "NN", "looking": "VBG", "investigation": "NN", "Bloomingdale": "NNP", "Commerce": "NNP", "cars": "NNS", "million": "CD", "possibility": "NN", "language": "NN", "7\\/8": "CD", "thing": "NN", "revenue": "NN", "There": "EX", "directly": "RB", "corporations": "NNS", "Hollywood": "NNP", "tomorrow": "NN", "millions": "NNS", "city": "NN", "given": "VBN", "district": "NN", "trillion": "CD", "Dow": "NNP", "anyone": "NN", "2": "CD", "SEC": "NNP", "white": "JJ", "gives": "VBZ", "a": "DT", "mostly": "RB", "season": "NN", "probably": "RB", "surged": "VBD", "than": "IN", "Inc.": "NNP", "11": "CD", "10": "CD", "television": "NN", "12": "CD", "15": "CD", "14": "CD", "17": "CD", "16": "CD", "19": "CD", "18": "CD", "spokeswoman": "NN", "officials": "NNS", "venture": "NN", "amid": "IN", "and": "CC", "Court": "NNP", "investors": "NNS", "Marcos": "NNP", "Philadelphia": "NNP", "sells": "VBZ", "any": "DT", "equipment": "NN", "intends": "VBZ", "performance": "NN", "Du": "NNP", "200": "CD", "normal": "JJ", "price": "NN", "remarks": "NNS", "D.": "NNP", "especially": "RB", "sale": "NN", "ways": "NNS", "senior": "JJ", "typically": "RB", "laws": "NNS", "rating": "NN", "commitments": "NNS", "aggressive": "JJ", "We": "PRP", "written": "VBN", "crime": "NN", "going": "VBG", "black": "JJ", "congressional": "JJ", "contracts": "NNS", "nearly": "RB", "morning": "NN", "miles": "NNS", "where": "WRB", "college": "NN", "Grand": "NNP", "concern": "NN", "mortgage": "NN", "farmers": "NNS", "federal": "JJ", "representatives": "NNS", "materials": "NNS", "weapons": "NNS", "between": "IN", "Mitsubishi": "NNP", "jobs": "NNS", "Johnson": 
"NNP", "U.S.": "NNP", "26": "CD", "Each": "DT", "article": "NN", "cities": "NNS", "acquiring": "VBG", "many": "JJ", "region": "NN", "according": "VBG", "contract": "NN", "holders": "NNS", "comes": "VBZ", "among": "IN", "cancer": "NN", "150": "CD", "period": "NN", ",": ",", "60": "CD", "considering": "VBG", "unusual": "JJ", "Calif.": "NNP", "Electric": "NNP", "But": "CC", "Lynch": "NNP", "500": "CD", "engine": "NN", "direction": "NN", "Analysts": "NNS", "former": "JJ", "those": "DT", "paying": "VBG", "To": "TO", "these": "DT", "consultant": "NN", "Reagan": "NNP", "cash": "NN", "n't": "RB", "policies": "NNS", "newspaper": "NN", "situation": "NN", "trader": "NN", "then": "RB", "metric": "JJ", "telephone": "NN", "Peters": "NNP", "technology": "NN", "Israel": "NNP", "media": "NNS", "same": "JJ", "events": "NNS", "status": "NN", "oil": "NN", "I": "PRP", "IRS": "NNP", "Toyota": "NNP", "Coors": "NNP", "director": "NN", "largely": "RB", "constitutional": "JJ", "roughly": "RB", "mortgages": "NNS", "Rep.": "NNP", "without": "IN", "In": "IN", "researchers": "NNS", "If": "IN", "summer": "NN", "United": "NNP", "Service": "NNP", "being": "VBG", "money": "NN", "actions": "NNS", "Daniel": "NNP", "announcement": "NN", "death": "NN", "rose": "VBD", "seems": "VBZ", "improvement": "NN", "4": "CD", "Although": "IN", "pill": "NN", "real": "JJ", "rules": "NNS", "Sachs": "NNP", "Ortega": "NNP", "inflation": "NN", "traffic": "NN", "using": "VBG", "'ve": "VBP", "annually": "RB", "audience": "NN", "London": "NNP", "retailers": "NNS", "fully": "RB", "Moreover": "RB", "Since": "IN", "competition": "NN", "Dr.": "NNP", "New": "NNP", "gross": "JJ", "legal": "JJ", "conservative": "JJ", "critical": "JJ", "deficit": "NN", "provides": "VBZ", "football": "NN", "scientific": "JJ", "power": "NN", "leadership": "NN", "manufacturer": "NN", "on": "IN", "central": "JJ", "S.A.": "NNP", "of": "IN", "industry": "NN", "Trade": "NNP", "airline": "NN", "or": "CC", "road": "NN", "outlook": "NN", "coupon": "NN", 
"instruments": "NNS", "image": "NN", "parties": "NNS", "your": "PRP$", "area": "NN", "Engelken": "NNP", "Bartlett": "NNP", "trying": "VBG", "with": "IN", "Guber": "NNP", "volume": "NN", "fraud": "NN", "House": "NNP", "pulp": "NN", "gone": "VBN", "ad": "NN", "certain": "JJ", "am": "VBP", "sales": "NNS", "Thursday": "NNP", "an": "DT", "at": "IN", "film": "NN", "USX": "NNP", "4.5": "CD", "again": "RB", "event": "NN", "field": "NN", "5": "CD", "you": "PRP", "Ross": "NNP", "Las": "NNP", "poor": "JJ", "Jaguar": "NNP", "students": "NNS", "includes": "VBZ", "important": "JJ", "coverage": "NN", "stocks": "NNS", "US$": "$", "assets": "NNS", "wife": "NN", "directors": "NNS", "Street": "NNP", "minister": "NN", "Canada": "NNP", "founder": "NN", "dollar": "NN", "5\\/8": "CD", "month": "NN", "settlement": "NN", "decisions": "NNS", "children": "NNS", "Brown": "NNP", "to": "TO", "program": "NN", "health": "NN", "lawmakers": "NNS", "activities": "NNS", "woman": "NN", "far": "RB", "difference": "NN", "`": "``", "cable": "NN", "--": ":", "large": "JJ", "small": "JJ", "rate": "NN", "lawyer": "NN", "investment": "NN", "HUD": "NNP", "Korea": "NNP", "consumers": "NNS", "Paribas": "NNP", "version": "NN", "scientists": "NNS", "Ogilvy": "NNP", "Bethlehem": "NNP", "full": "JJ", "hours": "NNS", "strong": "JJ", "thrift": "NN", "prosecutors": "NNS", "ahead": "RB", "houses": "NNS", "losses": "NNS", "social": "JJ", "action": "NN", "options": "NNS", "via": "IN", "family": "NN", "S.": "NNP", "establish": "VB", "Europe": "NNP", "shareholders": "NNS", "Dinkins": "NNP", "eye": "NN", "takes": "VBZ", "11\\/16": "CD", "Hewlett-Packard": "NNP", "two": "CD", "Corp": "NNP", "6": "CD", "taken": "VBN", "markets": "NNS", "Manville": "NNP", "Intel": "NNP", "division": "NN", "company": "NN", "producing": "VBG", "town": "NN", "keeping": "VBG", "hour": "NN", "nine": "CD", "history": "NN", "purchases": "NNS", "IBM": "NNP", "adviser": "NN", "share": "NN", "numbers": "NNS", "Thompson": "NNP", "sharp": "JJ", "!": ".", 
"huge": "JJ", "court": "NN", "goal": "NN", "rather": "RB", "Carpenter": "NNP", "earnings": "NNS", "plant": "NN", "different": "JJ", "response": "NN", "acquisitions": "NNS", "Mexico": "NNP", ")": ")", "banks": "NNS", "What": "WP", "soon": "RB", "paper": "NN", "committee": "NN", "signs": "NNS", "its": "PRP$", "Texas": "NNP", "24": "CD", "25": "CD", "style": "NN", "27": "CD", "20": "CD", "21": "CD", "22": "CD", "23": "CD", "28": "CD", "29": "CD", "actually": "RB", "systems": "NNS", "governments": "NNS", "might": "MD", "Moody": "NNP", "someone": "NN", "seeking": "VBG", "food": "NN", "Michael": "NNP", "bigger": "JJR", "easily": "RB", "always": "RB", "week": "NN", "everyone": "NN", "generation": "NN", "house": "NN", "energy": "NN", "reduce": "VB", "idea": "NN", "slowdown": "NN", "Joseph": "NNP", "advertisers": "NNS", "operation": "NN", "beyond": "IN", "insurance": "NN", "really": "RB", "E.": "NNP", "since": "IN", "temporary": "JJ", "research": "NN", "safety": "NN", "7": "CD", "According": "VBG", "EC": "NNP", "extraordinary": "JJ", "reason": "NN", "members": "NNS", "producers": "NNS", "owners": "NNS", "benefits": "NNS", "Boston": "NNP", "computers": "NNS", "threat": "NN", "pilots": "NNS", "major": "JJ", "Hugo": "NNP", "number": "NN", "feet": "NNS", "done": "VBN", "fees": "NNS", "story": "NN", "statement": "NN", "option": "NN", "relationship": "NN", "part": "NN", "kind": "NN", "grew": "VBD", "toward": "IN", "outstanding": "JJ", "Douglas": "NNP", "It": "PRP", "substantial": "JJ", "orders": "NNS", "ratings": "NNS", "majority": "NN", "internal": "JJ", "Drexel": "NNP", "chairman": "NN", "With": "IN", "75": "CD", "shareholder": "NN", "significant": "JJ", "70": "CD", "services": "NNS", "The": "DT", "extremely": "RB", "dealers": "NNS", "OTC": "NNP", "traditional": "JJ", "three-month": "JJ", "institutions": "NNS", "sector": "NN", "particularly": "RB", "session": "NN", "businesses": "NNS", "Poor": "NNP", "regulations": "NNS", "merger": "NN", "equity": "NN", "8": "CD", "Prime": 
"NNP", "his": "PRP$", "gains": "NNS", "While": "IN", "5,000": "CD", "closely": "RB", "During": "IN", "during": "IN", "him": "PRP", "merchandise": "NN", "six-month": "JJ", "J.": "NNP", "common": "JJ", "activity": "NN", "wrote": "VBD", "Chairman": "NNP", "For": "IN", "France": "NNP", "culture": "NN", "defense": "NN", "are": "VBP", "jury": "NN", "2.5": "CD", "#": "#", "movie": "NN", "currently": "RB", "case": "NN", "various": "JJ", "Sony": "NNP", "conditions": "NNS", "available": "JJ", "recently": "RB", "creating": "VBG", "dividends": "NNS", "attention": "NN", "Florida": "NNP", "succeed": "VB", "opposition": "NN", "dividend": "NN", "last": "JJ", "ANC": "NNP", "annual": "JJ", "foreign": "JJ", "connection": "NN", "became": "VBD", "long-term": "JJ", "Compaq": "NNP", "reasons": "NNS", "loan": "NN", "community": "NN", "simply": "RB", "throughout": "IN", "political": "JJ", "earthquake": "NN", "whom": "WP", "reduction": "NN", "California": "NNP", "treatment": "NN", "partly": "RB", "gas": "NN", "priced": "VBN", "brokers": "NNS", "prices": "NNS", "plants": "NNS", "bill": "NN", "elections": "NNS", "33": "CD", "31": "CD", "City": "NNP", "pound": "NN", "Italy": "NNP", "voters": "NNS", "cents": "NNS", "itself": "PRP", "seen": "VBN", "Co.": "NNP", "underwriters": "NNS", "virtually": "RB", "widely": "RB", "grand": "JJ", "9": "CD", "products": "NNS", "relatively": "RB", "development": "NN", "currencies": "NNS", "Allianz": "NNP", "affairs": "NNS", "yesterday": "NN", "moment": "NN", "levels": "NNS", "{": "(", "recent": "JJ", "Miller": "NNP", "person": "NN", "organization": "NN", "one-year": "JJ", "competitive": "JJ", "Boren": "NNP", "questions": "NNS", "world": "NN", "profitable": "JJ", "retirement": "NN", "$": "$", "over-the-counter": "JJ", "workers": "NNS", "source": "NN", "Germany": "NNP", "...": ":", "customers": "NNS", "Last": "JJ", "emergency": "NN", "Of": "IN", "Air": "NNP", "game": "NN", "necessary": "JJ", "projects": "NNS", "follows": "VBZ", "individuals": "NNS", "popular": 
"JJ", "often": "RB", "Gulf": "NNP", "some": "DT", "3\\/4": "CD", "economic": "JJ", "3\\/8": "CD", "Frank": "NNP", "decision": "NN", "transactions": "NNS", "quickly": "RB", "Massachusetts": "NNP", "be": "VB", "Brady": "NNP", "300": "CD", "agreement": "NN", "David": "NNP", "output": "NN", "abroad": "RB", "pipeline": "NN", "goods": "NNS", "anything": "NN", "Pont": "NNP", "Roy": "NNP", "ounce": "NN", "Committee": "NNP", "into": "IN", "within": "IN", "NEC": "NNP", "nothing": "NN", "primarily": "RB", "Quebecor": "NNP", "bankruptcy": "NN", ":": ":", "himself": "PRP", "vehicle": "NN", "Ms.": "NNP", "Ltd.": "NNP", "Switzerland": "NNP", "subsidiary": "NN", "line": "NN", "Bell": "NNP", "Africa": "NNP", "us": "PRP", "Thatcher": "NNP", "maturity": "NN", "'re": "VBP", "exploration": "NN", "Those": "DT", "similar": "JJ", "Perhaps": "RB", "Hampshire": "NNP", "Westinghouse": "NNP", "single": "JJ", "Edward": "NNP", "International": "NNP", "Manhattan": "NNP", "%": "NN", "May": "NNP", "politicians": "NNS", "Mae": "NNP", "income": "NN", "department": "NN", "AG": "NNP", "problems": "NNS", "helping": "VBG", "allowing": "VBG", "reinsurance": "NN", "sides": "NNS", "structure": "NN", "vice": "NN", "age": "NN", "vehicles": "NNS", "bankers": "NNS", "An": "DT", "At": "IN", "requires": "VBZ", "having": "VBG", "results": "NNS", "Department": "NNP", "issues": "NNS", "young": "JJ", "suits": "NNS", "citing": "VBG", "UAL": "NNP", "Not": "RB", "Now": "RB", "resources": "NNS", "P&G": "NNP", "automotive": "JJ", "continues": "VBZ", "Mrs.": "NNP", "putting": "VBG", "entire": "JJ", "positions": "NNS", "race": "NN", "smaller": "JJR", "crop": "NN", "Hutton": "NNP", "makers": "NNS", "index": "NN", "business": "NN", "giving": "VBG", "Alan": "NNP", "access": "NN", "volatile": "JJ", "firms": "NNS", "America": "NNP", "pushing": "VBG", "jointly": "RB", "others": "NNS", "great": "JJ", "38": "CD", "technical": "JJ", "Energy": "NNP", "larger": "JJR", "37": "CD", "35": "CD", "CBS": "NNP", "survey": "NN", "Motor": 
"NNP", "opinion": "NN", "residents": "NNS", "gene": "NN", "makes": "VBZ", "maker": "NN", "apple": "NN", "Robert": "NNP", "private": "JJ", "privately": "RB", "scandal": "NN", "from": "IN", "&": "CC", "few": "JJ", "Fe": "NNP", "year-ago": "JJ", "themselves": "PRP", "chip": "NN", "reflects": "VBZ", "Wednesday": "NNP", "sharply": "RB", "women": "NNS", "customer": "NN", "this": "DT", "clients": "NNS", "recession": "NN", "industrial": "JJ", "F.": "NNP", "Northern": "NNP", "tax": "NN", "Mr.": "NNP", "reserves": "NNS", "something": "NN", "Party": "NNP", "BellSouth": "NNP", "holds": "VBZ", "traders": "NNS", "instead": "RB", "stock": "NN", "ABC": "NNP", "Nissan": "NNP", "Terms": "NNS", "engineering": "NN", "lines": "NNS", "Community": "NNP", "Oct.": "NNP", "software": "NN", "six": "CD", "producer": "NN", "institutional": "JJ", "Smith": "NNP", "including": "VBG", "year-earlier": "JJ", "industries": "NNS", "Exchange": "NNP", "Brooks": "NNP", "labor": "NN", "willing": "JJ", "greater": "JJR", "auto": "NN", "practice": "NN", "investor": "NN", "day": "NN", "Supreme": "NNP", "San": "NNP", "bills": "NNS", "Corry": "NNP", "doing": "VBG", "books": "NNS", "Treasury": "NNP", "our": "PRP$", "80": "CD", "Unisys": "NNP", "entertainment": "NN", "critics": "NNS", "China": "NNP", "disclose": "VB", "This": "DT", "regulatory": "JJ", "could": "MD", "Lawson": "NNP", "succeeds": "VBZ", "powerful": "JJ", "strategic": "JJ", "owner": "NN", "management": "NN", "system": "NN", "relations": "NNS", "Coast": "NNP", "their": "PRP$", "Pilson": "NNP", "final": "JJ", "Association": "NNP", "interests": "NNS", "acquire": "VB", "environmental": "JJ", "chemicals": "NNS", "reflecting": "VBG", "steel": "NN", "colleagues": "NNS", "patients": "NNS", "Peter": "NNP", "creditors": "NNS", "1.4": "CD", "1.5": "CD", "1.6": "CD", "1.1": "CD", "Richard": "NNP", "1.3": "CD", "unchanged": "JJ", "partnership": "NN", "Other": "JJ", ";": ":", "apparently": "RB", "clearly": "RB", "Development": "NNP", "documents": "NNS", 
"Goldman": "NNP", "After": "IN", "able": "JJ", "instance": "NN", "which": "WDT", "unless": "IN", "who": "WP", "eight": "CD", "segment": "NN", "payment": "NN", "Reserve": "NNP", "so-called": "JJ", "Some": "DT", "MCA": "NNP", "1,000": "CD", "}": ")", "Saturday": "NNP", "fact": "NN", "Paris": "NNP", "Sansui": "NNP", "Chemical": "NNP", "Under": "IN", "portfolio": "NN", "economist": "NN", "decade": "NN", "staff": "NN", "partners": "NNS", "based": "VBN", "Meanwhile": "RB", "(": "(", "should": "MD", "candidates": "NNS", "York": "NNP", "employee": "NN", "local": "JJ", "bonds": "NNS", "familiar": "JJ", "120": "CD", "ones": "NNS", "words": "NNS", "exchanges": "NNS", "buyer": "NN", "chips": "NNS", "areas": "NNS", "Because": "IN", "trucks": "NNS", "course": "NN", "taxes": "NNS", "calling": "VBG", "Wall": "NNP", "she": "PRP", "Burnham": "NNP", "temporarily": "RB", "national": "JJ", "computer": "NN", "nuclear": "JJ", "state": "NN", "July": "NNP", "Sen.": "NNP", "ability": "NN", "agencies": "NNS", "job": "NN", "takeover": "NN", "approval": "NN", "problem": "NN", "declining": "VBG", "restrictions": "NNS", "drug": "NN", "1\\/2": "CD", "1\\/4": "CD", "1\\/8": "CD", "ca": "MD", "Los": "NNP", "addition": "NN", "genetic": "JJ", "agreements": "NNS", "proposal": "NN", "Toronto": "NNP", "Center": "NNP", "Navigation": "NNP", "And": "CC", "homes": "NNS", "unlike": "IN", "value": "NN", "will": "MD", "PLC": "NNP", "Delmed": "NNP", "owns": "VBZ", "almost": "RB", "thus": "RB", "site": "NN", "partner": "NN", "You": "PRP", "perhaps": "RB", "began": "VBD", "administration": "NN", "Bear": "NNP", "member": "NN", "when": "WRB", "parts": "NNS", "largest": "JJS", "units": "NNS", "party": "NN", "gets": "VBZ", "difficult": "JJ", "effect": "NN", "Mitchell": "NNP", "Houston": "NNP", "transaction": "NN", "Senate": "NNP", "wants": "VBZ", "350": "CD", "position": "NN", "Shearson": "NNP", "latest": "JJS", "stores": "NNS", "heavily": "RB", "increasingly": "RB", "domestic": "JJ", "obtain": "VB", "sources": 
"NNS", "Sunday": "NNP", "rooms": "NNS", "ads": "NNS", "Friday": "NNP", "book": "NN", "cosmetics": "NNS", "Despite": "IN", "By": "IN", "provisions": "NNS", "government": "NN", "five": "CD", "immediately": "RB", "loss": "NN", "England": "NNP", "Aug.": "NNP", "success": "NN", "B.": "NNP", "payments": "NNS", "amendment": "NN", "arbitrage": "NN", "Sun": "NNP", "Kidder": "NNP", "February": "NNP", "growth": "NN", "employment": "NN", "Singapore": "NNP", "broad": "JJ", "However": "RB", "does": "VBZ", "leader": "NN", "?": ".", "Baker": "NNP", "Rothschild": "NNP", "monetary": "JJ", "expansion": "NN", "Fujitsu": "NNP", "although": "IN", "loans": "NNS", "panel": "NN", "actual": "JJ", "debentures": "NNS", "December": "NNP", "Peabody": "NNP", "holdings": "NNS", "carries": "VBZ", "carrier": "NN", "Japan": "NNP", "executives": "NNS", "letters": "NNS", "previously": "RB", "warrants": "NNS", "Two": "CD", "getting": "VBG", "strategy": "NN", "utility": "NN", "1986": "CD", "1987": "CD", "1984": "CD", "1985": "CD", "1982": "CD", "additional": "JJ", "1980": "CD", "Lawrence": "NNP", "housing": "NN", "1988": "CD", "1989": "CD", "biggest": "JJS", "November": "NNP", "funds": "NNS", "brand": "NN", "but": "CC", "delivery": "NN", "construction": "NN", "highest": "JJS", "he": "PRP", "also": "RB", "Industrial": "NNP", "whether": "IN", "cells": "NNS", "Britain": "NNP", "distribution": "NN", "minutes": "NNS", "flight": "NN", "margins": "NNS", "mutual": "JJ", "compared": "VBN", "'ll": "MD", "48": "CD", "49": "CD", "46": "CD", "Jr.": "NNP", "44": "CD", "45": "CD", "42": "CD", "Yesterday": "NN", "40": "CD", "Volume": "NN", "other": "JJ", "details": "NNS", "Corp.": "NNP", "junk": "NN", "Like": "IN", "class": "NN", "March": "NNP", "April": "NNP", "chance": "NN", "Morgan": "NNP", "Act": "NNP", "factors": "NNS", "portion": "NN", "pension": "NN"}
nltk_data/taggers/averaged_perceptron_tagger_eng/averaged_perceptron_tagger_eng.weights.json ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "diffusion-speech-360h"
3
+ version = "0.1.0"
4
+ description = "A simple diffusion-based text to speech model"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "g2p-en>=2.1.0",
9
+ "gradio>=5.9.1",
10
+ "nltk>=3.9.1",
11
+ "soundfile>=0.12.1",
12
+ "torch>=2.5.1",
13
+ "vocos>=0.1.0",
14
+ ]
sample.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+
9
+ torch.backends.cuda.matmul.allow_tf32 = True
10
+ torch.backends.cudnn.allow_tf32 = True
11
+ import argparse
12
+ import os
13
+
14
+ import numpy as np
15
+ import torch
16
+ import yaml
17
+ from tqdm import tqdm
18
+
19
+ from diffusion import create_diffusion
20
+ from models import DiT_models
21
+
22
+
23
def find_model(model_name):
    """Return the state dict stored in a DiT checkpoint file.

    Checkpoints written by train.py carry an exponential-moving-average
    copy under the "ema" key; that copy is preferred when present,
    otherwise the raw "model" weights are returned.
    """
    assert os.path.isfile(model_name), f"Could not find DiT checkpoint at {model_name}"
    checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
    # Prefer EMA weights when the checkpoint provides them.
    key = "ema" if "ema" in checkpoint else "model"
    print("Using EMA model" if key == "ema" else "Using model")
    return checkpoint[key]
33
+
34
+
35
def get_batch(
    step, batch_size, seq_len, DEVICE, data_file, data_dim, data_mean, data_std
):
    """Sample one batch from the memmapped training data.

    The file is a flat float16 table whose rows are
    ``[features (data_dim), phone, speaker_id, phone_kind]``.
    Start offsets are drawn deterministically from ``step`` so the same
    step always yields the same batch.

    Returns (x, speaker_id, phone, phone_kind) where x is normalized and
    shaped (batch, data_dim, seq_len) and the rest are (batch, seq_len)
    int tensors on DEVICE.
    """
    n_cols = data_dim + 3
    # First open just to learn the total length, then re-open with the
    # proper 2-D row layout.
    flat = np.memmap(data_file, dtype=np.float16, mode="r")
    table = np.memmap(
        data_file,
        dtype=np.float16,
        mode="r",
        shape=(flat.shape[0] // n_cols, n_cols),
    )

    # Deterministic per-step sampling of window start positions.
    rng = np.random.Generator(np.random.PCG64(seed=step))
    starts = rng.choice(
        table.shape[0] - seq_len, size=batch_size, replace=False
    ).astype(np.int64)

    # Materialize the selected windows (copies out of the memmap).
    rows = np.stack(
        [np.asarray(table[s : s + seq_len]) for s in starts]
    ).astype(np.float16)

    # Split columns: features, then the three integer label columns.
    feats = np.moveaxis(rows[:, :, :data_dim].astype(np.float16), 1, 2)
    phone_np = rows[:, :, data_dim].astype(np.int32)
    speaker_np = rows[:, :, data_dim + 1].astype(np.int32)
    kind_np = rows[:, :, data_dim + 2].astype(np.int32)

    # Move to device and normalize the features.
    x = torch.from_numpy(feats).to(DEVICE)
    x = (x - data_mean) / data_std
    phone = torch.from_numpy(phone_np).to(DEVICE)
    speaker_id = torch.from_numpy(speaker_np).to(DEVICE)
    phone_kind = torch.from_numpy(kind_np).to(DEVICE)

    return x, speaker_id, phone, phone_kind
76
+
77
+
78
def get_data(config_path, seed=0):
    """Draw a single ground-truth batch described by a training YAML config.

    Reads the ``data`` and ``model`` sections of the config and forwards
    them to ``get_batch`` with batch size 1, using ``seed`` as the
    deterministic sampling step.
    """
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    data_cfg = config["data"]
    model_cfg = config["model"]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    return get_batch(
        seed,
        1,
        seq_len=model_cfg["input_size"],
        DEVICE=device,
        data_file=data_cfg["data_path"],
        data_dim=data_cfg["data_dim"],
        data_mean=data_cfg["data_mean"],
        data_std=data_cfg["data_std"],
    )
98
+
99
+
100
def plot_samples(samples, x):
    """Render the diffusion trajectory in ``samples`` as ``animation.mp4``.

    Each frame shows one denoising step: a spectrogram image when the
    sample has more than one channel, or the predicted curve overlaid on
    the ground truth ``x`` for single-channel (duration) data.

    Args:
        samples: Sequence of (1, C, T) tensors, one per diffusion step.
        x: Ground-truth (1, C, T) tensor, plotted only in the 1-D case.
    """
    # Imported lazily: matplotlib is not imported at module level in this
    # file, so importing here both fixes the missing dependency and keeps
    # sampling usable without matplotlib installed.
    import matplotlib.pyplot as plt
    from matplotlib import animation

    # Create figure and axis
    fig, ax = plt.subplots(figsize=(20, 4))
    plt.tight_layout()

    # Function to update frame
    def update(frame):
        ax.clear()
        # Step counter in the top-left corner.
        ax.text(
            0.02,
            0.98,
            f"{frame+1} / 1000",
            transform=ax.transAxes,
            verticalalignment="top",
            color="black",
        )
        if samples[frame].shape[1] > 1:
            # Multi-channel data (e.g. mel spectrogram): draw as an image.
            im = ax.imshow(
                samples[frame].cpu().numpy()[0],
                origin="lower",
                aspect="auto",
                interpolation="none",
                vmin=-5,
                vmax=5,
            )
            return [im]
        elif samples[frame].shape[1] == 1:
            # Single-channel data (e.g. durations): prediction vs. target.
            line1 = ax.plot(samples[frame].cpu().numpy()[0, 0])[0]
            line2 = ax.plot(x.cpu().numpy()[0, 0])[0]
            plt.ylim(-10, 10)
            return [line1, line2]

    # Create animation with progress bar; interval matches the 60 fps
    # used when saving below.
    anim = animation.FuncAnimation(
        fig,
        update,
        frames=tqdm(range(len(samples)), desc="Generating animation"),
        interval=1000 / 60,
        blit=True,
    )

    # Save as MP4
    anim.save("animation.mp4", fps=60, extra_args=["-vcodec", "libx264"])
    plt.close()
144
+
145
+
146
def sample(
    config_path,
    ckpt_path,
    cfg_scale=4.0,
    num_sampling_steps=1000,
    seed=0,
    speaker_id=None,
    phone=None,
    phone_kind=None,
):
    """Run classifier-free-guided diffusion sampling with a trained DiT.

    Args:
        config_path: YAML file with ``model`` and ``data`` sections.
        ckpt_path: Checkpoint path, loaded via ``find_model``.
        cfg_scale: Classifier-free guidance scale forwarded to the model.
        num_sampling_steps: Number of diffusion steps.
        seed: Torch RNG seed for the initial noise.
        speaker_id: Int tensor of per-position speaker ids; its second
            dimension defines the generated sequence length. Assumed
            batch size 1 (``n = 1`` below) — confirm with callers.
        phone: Int tensor of phoneme indices, same shape as speaker_id.
        phone_kind: Int tensor of phoneme-kind indices, same shape.

    Returns:
        A list of tensors (the per-step sampling trajectory), each with
        the conditional half only, shape (1, data_dim, seq_len).
    """
    torch.manual_seed(seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    data_config = config["data"]
    model_config = config["model"]

    # Load model:
    model = DiT_models[model_config["name"]](
        input_size=model_config["input_size"],
        embedding_vocab_size=model_config["embedding_vocab_size"],
        learn_sigma=model_config["learn_sigma"],
        in_channels=data_config["data_dim"],
    ).to(device)

    state_dict = find_model(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()  # important!
    diffusion = create_diffusion(str(num_sampling_steps))
    n = 1
    # Initial Gaussian noise; sequence length comes from the conditioning.
    z = torch.randn(n, data_config["data_dim"], speaker_id.shape[1], device=device)

    # Positions may attend only to positions sharing the same speaker id.
    attn_mask = speaker_id[:, None, :] == speaker_id[:, :, None]
    attn_mask = attn_mask.unsqueeze(1)
    # Duplicate the mask for the conditional + unconditional halves below.
    attn_mask = torch.cat([attn_mask, attn_mask], 0)
    # Setup classifier-free guidance:
    # batch = [conditional, unconditional]; the null half uses the
    # embedder's dedicated unconditional token for every condition.
    z = torch.cat([z, z], 0)
    unconditional_value = model.y_embedder.unconditional_value
    phone_null = torch.full_like(phone, unconditional_value)
    speaker_id_null = torch.full_like(speaker_id, unconditional_value)
    phone = torch.cat([phone, phone_null], 0)
    speaker_id = torch.cat([speaker_id, speaker_id_null], 0)
    phone_kind_null = torch.full_like(phone_kind, unconditional_value)
    phone_kind = torch.cat([phone_kind, phone_kind_null], 0)
    model_kwargs = dict(
        phone=phone,
        speaker_id=speaker_id,
        phone_kind=phone_kind,
        cfg_scale=cfg_scale,
        attn_mask=attn_mask,
    )

    # p_sample_loop is iterated below, i.e. it yields/returns the full
    # per-step trajectory rather than only the final sample.
    samples = diffusion.p_sample_loop(
        model.forward_with_cfg,
        z.shape,
        z,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device=device,
    )
    samples = [s.chunk(2, dim=0)[0] for s in samples]  # Remove null class samples
    return samples
212
+
213
+
214
if __name__ == "__main__":
    # CLI entry point: draw one ground-truth batch from the dataset, run
    # diffusion sampling conditioned on its labels, and render the
    # resulting trajectory to animation.mp4.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    cli.add_argument("--ckpt", type=str, required=True)
    cli.add_argument("--cfg-scale", type=float, default=4.0)
    cli.add_argument("--num-sampling-steps", type=int, default=1000)
    cli.add_argument("--seed", type=int, default=0)
    opts = cli.parse_args()

    x, speaker_id, phone, phone_kind = get_data(opts.config, opts.seed)
    trajectory = sample(
        opts.config,
        opts.ckpt,
        opts.cfg_scale,
        opts.num_sampling_steps,
        opts.seed,
        speaker_id,
        phone,
        phone_kind,
    )
    plot_samples(trajectory, x)
synthesize.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Synthesize a given text using the trained DiT models.
3
+ """
4
+
5
+ import json
6
+ import os
7
+
8
+ os.environ["NLTK_DATA"] = "nltk_data"
9
+ import torch
10
+ import yaml
11
+ from g2p_en import G2p
12
+ import soundfile as sf
13
+ from vocos import Vocos
14
+ from sample import sample
15
+
16
+
17
def synthesize(
    text,
    duration_model_config,
    duration_model_checkpoint,
    acoustic_model_config,
    acoustic_model_checkpoint,
    speaker_id,
    cfg_scale=4.0,
    num_sampling_steps=1000,
):
    """
    Synthesize speech from text using trained DiT models.

    Pipeline: text -> phonemes (g2p) -> per-phoneme durations (duration
    DiT) -> frame-aligned phoneme sequence -> mel spectrogram (acoustic
    DiT) -> waveform (Vocos vocoder).

    Args:
        text (str): Input text to synthesize
        duration_model_config (str): Path to duration model config file
        duration_model_checkpoint (str): Path to duration model checkpoint
        acoustic_model_config (str): Path to acoustic model config file
        acoustic_model_checkpoint (str): Path to acoustic model checkpoint
        speaker_id (str): Speaker ID to use for synthesis
        cfg_scale (float): Classifier-free guidance scale (default: 4.0)
        num_sampling_steps (int): Number of sampling steps for diffusion (default: 1000)

    Returns:
        numpy.ndarray: Audio waveform array
        int: Sample rate (24000)

    Raises:
        ValueError: If the text yields no phonemes known to the model.
    """

    print("Text:", text)

    # Read duration model config
    with open(duration_model_config, "r") as f:
        duration_config = yaml.safe_load(f)

    # Get data directory from data_path
    data_dir = os.path.dirname(duration_config["data"]["data_path"])

    # Read the phone / phone-kind vocabularies from the same directory.
    with open(os.path.join(data_dir, "maps.json"), "r") as f:
        maps = json.load(f)
    phone_to_idx = maps["phone_to_idx"]
    phone_kind_to_idx = maps["phone_kind_to_idx"]

    # Step 1: Text to phonemes
    def text_to_phonemes(text, insert_empty=True):
        """Return (phones, phone_kinds) for *text*; punctuation -> EMPTY."""
        g2p = G2p()
        phonemes = g2p(text)
        # Group the flat g2p output into space-separated words.
        words = []
        word = []
        for p in phonemes:
            if p == " ":
                if len(word) > 0:
                    words.append(word)
                    word = []
            else:
                word.append(p)
        if len(word) > 0:
            words.append(word)

        phones = []
        phone_kinds = []
        for word in words:
            for i, p in enumerate(word):
                if p in [",", ".", "!", "?", ";", ":"]:
                    # Punctuation marks become explicit pauses.
                    p = "EMPTY"
                elif p in phone_to_idx:
                    pass
                else:
                    # Skip symbols the model has no embedding for.
                    continue

                # Tag each phone with its position within the word.
                if p == "EMPTY":
                    phone_kind = "EMPTY"
                elif len(word) == 1:
                    phone_kind = "WORD"
                elif i == 0:
                    phone_kind = "START"
                elif i == len(word) - 1:
                    phone_kind = "END"
                else:
                    phone_kind = "MIDDLE"

                phones.append(p)
                phone_kinds.append(phone_kind)

        # Pad with silence at both ends. Guarding on non-empty output
        # prevents an IndexError on phones[0] for unusable input text.
        if insert_empty and phones:
            if phones[0] != "EMPTY":
                phones.insert(0, "EMPTY")
                phone_kinds.insert(0, "EMPTY")
            if phones[-1] != "EMPTY":
                phones.append("EMPTY")
                phone_kinds.append("EMPTY")

        return phones, phone_kinds

    phonemes, phone_kinds = text_to_phonemes(text)
    if not phonemes:
        raise ValueError(f"No synthesizable phonemes found in text: {text!r}")
    # Convert phonemes to indices
    phoneme_indices = [phone_to_idx[p] for p in phonemes]
    phone_kind_indices = [phone_kind_to_idx[p] for p in phone_kinds]
    print("Phonemes:", phonemes)

    # Step 2: Duration prediction
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch_phoneme_indices = torch.tensor(phoneme_indices)[None, :].long().to(device)
    torch_speaker_id = torch.full_like(torch_phoneme_indices, int(speaker_id))
    torch_phone_kind_indices = (
        torch.tensor(phone_kind_indices)[None, :].long().to(device)
    )

    samples = sample(
        duration_model_config,
        duration_model_checkpoint,
        cfg_scale=cfg_scale,
        num_sampling_steps=num_sampling_steps,
        seed=0,
        speaker_id=torch_speaker_id,
        phone=torch_phoneme_indices,
        phone_kind=torch_phone_kind_indices,
    )
    # Last trajectory step, first batch item, first (only) channel.
    phoneme_durations = samples[-1][0, 0]

    # Step 3: Acoustic prediction
    # First, convert phoneme durations (seconds) to frame counts
    # (min 1 frame per phoneme).
    SAMPLE_RATE = 24000
    HOP_LENGTH = 256
    time_per_frame = HOP_LENGTH / SAMPLE_RATE
    # De-normalize predicted durations using data mean and std from the config.
    if duration_config["data"]["normalize"]:
        mean = duration_config["data"]["data_mean"]
        std = duration_config["data"]["data_std"]
        raw_durations = phoneme_durations * std + mean
    else:
        raw_durations = phoneme_durations

    # Clamp to [one frame, 1 second] per phoneme, then place phoneme
    # boundaries on the frame grid via cumulative end times.
    raw_durations = raw_durations.clamp(min=time_per_frame, max=1.0)
    end_time = torch.cumsum(raw_durations, dim=0)
    end_frame = end_time / time_per_frame
    int_end_frame = end_frame.floor().int()
    repeated_phoneme_indices = []
    repeated_phone_kind_indices = []
    for i in range(len(phonemes)):
        # Repeat each phoneme up to its (cumulative) end frame.
        repeated_phoneme_indices.extend(
            [phoneme_indices[i]] * (int_end_frame[i] - len(repeated_phoneme_indices))
        )
        repeated_phone_kind_indices.extend(
            [phone_kind_indices[i]]
            * (int_end_frame[i] - len(repeated_phone_kind_indices))
        )

    torch_phoneme_indices = (
        torch.tensor(repeated_phoneme_indices)[None, :].long().to(device)
    )
    torch_speaker_id = torch.full_like(torch_phoneme_indices, int(speaker_id))
    torch_phone_kind_indices = (
        torch.tensor(repeated_phone_kind_indices)[None, :].long().to(device)
    )

    samples = sample(
        acoustic_model_config,
        acoustic_model_checkpoint,
        cfg_scale=cfg_scale,
        num_sampling_steps=num_sampling_steps,
        seed=0,
        speaker_id=torch_speaker_id,
        phone=torch_phoneme_indices,
        phone_kind=torch_phone_kind_indices,
    )
    mel = samples[-1][0]
    # De-normalize the mel if the acoustic model was trained on normalized
    # data (with-statement closes the config file, unlike a bare open()).
    with open(acoustic_model_config, "r") as f:
        acoustic_config = yaml.safe_load(f)
    if acoustic_config["data"]["normalize"]:
        mean = acoustic_config["data"]["data_mean"]
        std = acoustic_config["data"]["data_std"]
        raw_mel = mel * std + mean
    else:
        raw_mel = mel

    # Step 4: Vocoder
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
    audio = vocos.decode(raw_mel.cpu()[None, :, :]).squeeze().cpu().numpy()

    return audio, SAMPLE_RATE
train_acoustic_dit_b.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ name: DiT-B
3
+ input_size: 2048
4
+ embedding_vocab_size: 1024
5
+ learn_sigma: true
6
+ optimization:
7
+ constant_memory: false
8
+ epochs: 1400
9
+ global_batch_size: 32
10
+ initial_input_size: 32
11
+ learning_rate: 1.0e-4
12
+ min_lr: 1.0e-5
13
+ warmup_iters: 10000
14
+ lr_decay_iters: 100000
15
+ decay_lr: true
16
+ weight_decay: 0.0
17
+ max_grad_norm: 20.0
18
+ betas:
19
+ beta1: 0.9
20
+ beta2: 0.999
21
+ loss:
22
+ num_timesteps: 1000
23
+ data:
24
+ data_path: acoustic.npy
25
+ data_dim: 100
26
+ data_std: 2.0
27
+ data_mean: -1.0
28
+ normalize: true
29
+ training:
30
+ enable_compile: true
31
+ use_bfloat16: true
32
+ use_block_mask: false
33
+ seed: 42
34
+ ckpt_every: 10_000
35
+ log_every: 100
36
+ results_dir: results/acoustic
37
+ resume_from_ckpt: null
38
+ wandb:
39
+ enable: true
40
+ project: diffusion-speech
train_duration_dit_s.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ name: DiT-S
3
+ input_size: 512
4
+ embedding_vocab_size: 1024
5
+ learn_sigma: true
6
+ optimization:
7
+ constant_memory: false
8
+ epochs: 1400
9
+ global_batch_size: 512
10
+ initial_input_size: 32
11
+ learning_rate: 1.0e-4
12
+ min_lr: 1.0e-5
13
+ warmup_iters: 10000
14
+ lr_decay_iters: 100000
15
+ decay_lr: true
16
+ weight_decay: 0.0
17
+ max_grad_norm: 20.0
18
+ betas:
19
+ beta1: 0.9
20
+ beta2: 0.999
21
+ loss:
22
+ num_timesteps: 1000
23
+ data:
24
+ data_path: duration.npy
25
+ data_dim: 1
26
+ data_std: 0.067776896
27
+ data_mean: 0.08663661
28
+ normalize: true
29
+ training:
30
+ enable_compile: true
31
+ use_bfloat16: true
32
+ use_block_mask: false
33
+ seed: 42
34
+ ckpt_every: 10_000
35
+ log_every: 100
36
+ results_dir: results/duration
37
+ resume_from_ckpt: null
38
+ wandb:
39
+ enable: true
40
+ project: diffusion-speech
uv.lock ADDED
The diff for this file is too large to render. See raw diff