|
import os

# Cap OpenMP at one thread per process BEFORE importing numpy/cv2 (they read
# this at import time): the script parallelizes with multiprocessing, and
# per-worker threading would oversubscribe the CPU.
os.environ["OMP_NUM_THREADS"] = "1"

import sys



import glob

import cv2

import tqdm

import numpy as np

from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker

from utils.commons.multiprocess_utils import multiprocess_run_tqdm

import warnings

# Suppress third-party warnings (mediapipe/opencv) so worker logs stay readable.
warnings.filterwarnings('ignore')



import random

# Fixed seed for reproducibility of any randomized ordering.
random.seed(42)



import pickle

import json

import gzip

from typing import Any
|
|
|
def load_file(filename, is_gzip: bool = False, is_json: bool = False) -> Any:
    """Deserialize an object from ``filename``.

    Args:
        filename: Path of the file to read.
        is_gzip: If True, treat the file as gzip-compressed.
        is_json: If True, parse as JSON; otherwise unpickle.

    Returns:
        The deserialized object.
    """
    if is_json:
        if is_gzip:
            # BUGFIX: gzip.open mode "r" is *binary* and rejects the
            # ``encoding`` argument with a ValueError; json needs a
            # text-mode handle, so open with "rt".
            with gzip.open(filename, "rt", encoding="utf-8") as f:
                return json.load(f)
        with open(filename, "r", encoding="utf-8") as f:
            return json.load(f)
    # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
    if is_gzip:
        with gzip.open(filename, "rb") as f:
            return pickle.load(f)
    with open(filename, "rb") as f:
        return pickle.load(f)
|
|
|
def save_file(filename, content, is_gzip: bool = False, is_json: bool = False) -> None:
    """Serialize ``content`` to ``filename``.

    Args:
        filename: Destination path.
        content: Object to serialize.
        is_gzip: If True, gzip-compress the output.
        is_json: If True, write JSON; otherwise pickle.
    """
    if is_json:
        if is_gzip:
            # BUGFIX: gzip.open mode "w" is *binary* and rejects the
            # ``encoding`` argument; json.dump writes str, so a text-mode
            # handle ("wt") is required.
            with gzip.open(filename, "wt", encoding="utf-8") as f:
                json.dump(content, f)
        else:
            with open(filename, "w", encoding="utf-8") as f:
                json.dump(content, f)
    else:
        if is_gzip:
            with gzip.open(filename, "wb") as f:
                pickle.dump(content, f)
        else:
            with open(filename, "wb") as f:
                pickle.dump(content, f)
|
|
|
# Per-process singleton, created lazily inside extract_lms_mediapipe_job so
# that each multiprocessing worker builds its own landmarker after fork.
face_landmarker = None
|
|
|
def extract_lms_mediapipe_job(img):
    """Return MediaPipe 478-point landmarks for a single RGB image.

    Returns None when ``img`` is None. The landmarker instance is built
    on first use and cached in the module-level ``face_landmarker`` so
    each worker process initializes MediaPipe exactly once.
    """
    global face_landmarker
    if img is None:
        return None
    if face_landmarker is None:
        # Lazy one-time construction per process.
        face_landmarker = MediapipeLandmarker()
    return face_landmarker.extract_lm478_from_img(img)
|
|
|
def extract_landmark_job(img_name):
    """Extract MediaPipe landmarks for one image and save them as ``.npy``.

    The output path mirrors the input path, with ``/images_512/`` replaced
    by ``/lms_2d/`` and the ``.png`` suffix replaced by ``_lms.npy``.
    Existing outputs are skipped. Any failure is printed and swallowed so
    a single bad image cannot abort a long batch run.
    """
    try:
        out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png", "_lms.npy")
        if os.path.exists(out_name):
            print("out exists, skip...")
            return
        # exist_ok=True already tolerates concurrent creation by other
        # workers; any real failure falls through to the outer handler.
        os.makedirs(os.path.dirname(out_name), exist_ok=True)
        img = cv2.imread(img_name)
        # BUGFIX: check for a failed read BEFORE slicing. The original
        # sliced ``cv2.imread(...)[:,:,::-1]`` first, so an unreadable
        # image raised TypeError and its None-check was dead code.
        if img is None:
            return
        img = img[:, :, ::-1]  # BGR -> RGB for MediaPipe
        lm468 = extract_lms_mediapipe_job(img)
        if lm468 is not None:
            np.save(out_name, lm468)
    except Exception as e:
        # Best-effort batch processing: log and continue.
        print(e)
|
|
|
def out_exist_job(img_name):
    """Return ``img_name`` if its landmark output is still missing, else None.

    Used to filter a worklist down to the images that still need processing;
    the output path is derived the same way as in extract_landmark_job.
    """
    out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png", "_lms.npy")
    return None if os.path.exists(out_name) else img_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    import argparse, glob, tqdm, random

    parser = argparse.ArgumentParser()
    parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512/')
    parser.add_argument("--ds_name", default='FFHQ')
    parser.add_argument("--num_workers", default=64, type=int)
    parser.add_argument("--process_id", default=0, type=int)
    parser.add_argument("--total_process", default=1, type=int)
    parser.add_argument("--reset", action='store_true')
    parser.add_argument("--img_names_file", default="img_names.pkl", type=str)
    parser.add_argument("--load_img_names", action="store_true")

    args = parser.parse_args()
    print(f"args {args}")
    img_dir = args.img_dir
    img_names_file = os.path.join(img_dir, args.img_names_file)

    if args.load_img_names:
        # Reuse a previously pickled worklist instead of re-globbing.
        img_names = load_file(img_names_file)
        print(f"load image names from {img_names_file}")
    else:
        # Build the worklist from dataset-specific glob patterns.
        if args.ds_name == 'FFHQ_MV':
            img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
            img_names1 = glob.glob(img_name_pattern1)
            img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
            img_names2 = glob.glob(img_name_pattern2)
            img_names = sorted(img_names1 + img_names2)
        elif args.ds_name == 'FFHQ':
            img_name_pattern = os.path.join(img_dir, "*.png")
            img_names = sorted(glob.glob(img_name_pattern))
        elif args.ds_name == "PanoHeadGen":
            img_name_patterns = ["ref/*/*.png"]
            img_names = []
            for img_name_pattern in img_name_patterns:
                img_name_pattern_full = os.path.join(img_dir, img_name_pattern)
                img_names.extend(glob.glob(img_name_pattern_full))
            img_names = sorted(img_names)
        else:
            # Fail fast: the original fell through and hit a NameError on
            # img_names further down for unknown dataset names.
            raise ValueError(f"unknown ds_name: {args.ds_name}")

    if not args.load_img_names:
        save_file(img_names_file, img_names)
        print(f"save image names in {img_names_file}")

    print(f"total images number: {len(img_names)}")

    # Shard the worklist across `total_process` cooperating script invocations.
    process_id = args.process_id
    total_process = args.total_process
    if total_process > 1:
        assert process_id <= total_process - 1
        num_samples_per_process = len(img_names) // total_process
        # BUGFIX: the last shard must absorb the remainder. The original
        # compared ``process_id == total_process``, which is unreachable
        # after the assert above, so len(img_names) % total_process
        # trailing images were silently dropped by every run.
        if process_id == total_process - 1:
            img_names = img_names[process_id * num_samples_per_process:]
        else:
            img_names = img_names[process_id * num_samples_per_process: (process_id + 1) * num_samples_per_process]

    print(f"todo_image {img_names[:10]}")
    print(f"processing images number in this process: {len(img_names)}")

    if args.num_workers == 1:
        # Serial path: process in-process with periodic flushed progress lines
        # (roughly every 0.3% of the worklist).
        index = 0
        for img_name in tqdm.tqdm(img_names, desc=f"Root process {args.process_id}: extracting MP-based landmark2d"):
            try:
                extract_landmark_job(img_name)
            except Exception as e:
                print(e)
            if index % max(1, int(len(img_names) * 0.003)) == 0:
                print(f"processed {index} / {len(img_names)}")
                sys.stdout.flush()
            index += 1
    else:
        # Parallel path: fan the jobs out over a worker pool.
        for i, res in multiprocess_run_tqdm(
                extract_landmark_job, img_names,
                num_workers=args.num_workers,
                desc=f"Root {args.process_id}: extracting MP-based landmark2d"):
            print(f"processed {i+1} / {len(img_names)}")
            sys.stdout.flush()
    print(f"Root {args.process_id}: Finished extracting.")