# File size: 3,310 Bytes
# 7cdd981
"""
 # Copyright 2020 Adobe
 # All Rights Reserved.
 
 # NOTICE: Adobe permits you to use, modify, and distribute this file in
 # accordance with the terms of the Adobe license agreement accompanying
 # it.
 
"""

import sys
sys.path.append('thirdparty/AdaptiveWingLoss')
import os, glob
import numpy as np
import cv2
import argparse
from src.dataset.image_translation import landmark_extraction, landmark_image_to_data
from approaches.train_image_translation import Image_translation_block
import platform
import torch


# Default data/output locations, picked per machine. A kernel release of
# '4.4.0-83-generic' identifies the local workstation. Any other kernel (the
# cluster reports 3.10.0-957.21.2.el7.x86_64) uses the NFS scratch layout.
# jpg_dir / ckpt_dir / log_dir serve as argparse defaults and may be overridden
# from the command line.
if platform.release() != '4.4.0-83-generic':
    # Cluster layout: preprocessed Vox data under a common NFS root.
    root = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_imagetranslation'
    src_dir = os.path.join(root, 'raw_fl3d')
    mp4_dir = r'/mnt/nfs/scratch1/yangzhou/PreprocessedVox_mp4'
    jpg_dir, ckpt_dir, log_dir = (
        os.path.join(root, sub) for sub in ('tmp_v', 'ckpt', 'log')
    )
else:
    # Local workstation layout: every output goes to the same relative folder.
    src_dir = r'/mnt/ntfs/Dataset/TalkingToon/VoxCeleb2_imagetranslation/raw_fl3d'
    mp4_dir = r'/mnt/ntfs/Dataset/VoxCeleb2/train_set/dev/mp4'
    jpg_dir = ckpt_dir = log_dir = r'img_output'

''' Step 1. Data preparation '''
# Landmark extraction (currently disabled; takes two integer command-line
# arguments):
# landmark_extraction(int(sys.argv[1]), int(sys.argv[2]))

# Saving the image data ahead of time is disabled because the resulting file
# is too large; the data are created online during training instead.
# landmark_image_to_data(0, 0, show=False)

''' Step 2. Train the network '''
# Command-line options. Directory defaults come from the machine-specific
# jpg_dir / ckpt_dir / log_dir chosen above.
parser = argparse.ArgumentParser()
parser.add_argument('--nepoch', type=int, default=150, help='number of epochs to train for')
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
# Help texts for --num_frames / --num_workers were previously swapped: the
# "frames per video" description sat on --num_workers.
parser.add_argument('--num_frames', type=int, default=1,
                    help='number of frames extracted from each video')
# NOTE(review): presumably forwarded to a torch DataLoader — confirm.
parser.add_argument('--num_workers', type=int, default=4,
                    help='number of data-loading worker processes')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')

# Run-mode flags: --train selects training; without it the test pass runs.
parser.add_argument('--write', default=False, action='store_true')
parser.add_argument('--train', default=False, action='store_true')
parser.add_argument('--name', type=str, default='tmp')
parser.add_argument('--test_speed', default=False, action='store_true')

parser.add_argument('--jpg_dir', type=str, default=jpg_dir)
parser.add_argument('--ckpt_dir', type=str, default=ckpt_dir)
parser.add_argument('--log_dir', type=str, default=log_dir)

# NOTE(review): frequencies below are presumably in iterations/epochs as the
# names suggest — confirm against Image_translation_block.
parser.add_argument('--jpg_freq', type=int, default=50, help='')
parser.add_argument('--ckpt_last_freq', type=int, default=1000, help='')
parser.add_argument('--ckpt_epoch_freq', type=int, default=1, help='')

parser.add_argument('--load_G_name', type=str, default='')
parser.add_argument('--use_vox_dataset', type=str, default='raw')


parser.add_argument('--add_audio_in', default=False, action='store_true')
parser.add_argument('--comb_fan_awing', default=False, action='store_true')
parser.add_argument('--fan_2or3D', type=str, default='3D')

# Non-empty value switches the script into single-sample test mode.
parser.add_argument('--single_test', type=str, default='')

opt_parser = parser.parse_args()


# Build the image-translation model from the parsed options, then run exactly
# one mode: single-sample test (--single_test), training (--train), or the
# default test pass. The two test modes run under torch.no_grad() so no
# autograd graph is recorded.
model = Image_translation_block(opt_parser)

if opt_parser.single_test != '':
    # Stop after the single-sample test; previously execution fell through and
    # also launched a full training or test run.
    with torch.no_grad():
        model.single_test()
elif opt_parser.train:
    model.train()
else:
    with torch.no_grad():
        model.test()