ameerazam08's picture
Upload folder using huggingface_hub
e34aada verified
import os
os.environ["OMP_NUM_THREADS"] = "1"
import sys
import glob
import cv2
import tqdm
import numpy as np
from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
from utils.commons.multiprocess_utils import multiprocess_run_tqdm
import warnings
warnings.filterwarnings('ignore')
import random
random.seed(42)
import pickle
import json
import gzip
from typing import Any
def load_file(filename, is_gzip: bool = False, is_json: bool = False) -> Any:
    """Load an object from a JSON or pickle file, optionally gzip-compressed.

    Args:
        filename: Path of the file to read.
        is_gzip: Read the file through ``gzip`` when true.
        is_json: Parse the file as UTF-8 JSON when true, otherwise unpickle it.

    Returns:
        The deserialized object.

    Note:
        Unpickling can execute arbitrary code; only load pickle files that
        come from a trusted source.
    """
    opener = gzip.open if is_gzip else open
    if is_json:
        # Text mode has to be spelled "rt": gzip.open treats a bare "r" as
        # binary mode and rejects the ``encoding`` argument there.
        with opener(filename, "rt", encoding="utf-8") as f:
            return json.load(f)
    with opener(filename, "rb") as f:
        return pickle.load(f)
def save_file(filename, content, is_gzip: bool = False, is_json: bool = False) -> None:
    """Write an object to a JSON or pickle file, optionally gzip-compressed.

    Args:
        filename: Path of the file to write (overwritten if it exists).
        content: Object to serialize.
        is_gzip: Write the file through ``gzip`` when true.
        is_json: Serialize as UTF-8 JSON when true, otherwise as a pickle.
    """
    opener = gzip.open if is_gzip else open
    if is_json:
        # Text mode has to be spelled "wt": gzip.open treats a bare "w" as
        # binary mode and rejects the ``encoding`` argument there.
        with opener(filename, "wt", encoding="utf-8") as f:
            json.dump(content, f)
    else:
        with opener(filename, "wb") as f:
            pickle.dump(content, f)
# Module-level landmarker instance; built on the first call that needs it and
# reused by every later call made in the same process.
face_landmarker = None


def extract_lms_mediapipe_job(img):
    """Run the MediaPipe landmarker on ``img`` and return its 478-point result.

    Returns ``None`` without touching the landmarker when ``img`` is ``None``;
    otherwise returns whatever ``extract_lm478_from_img`` produces.
    """
    global face_landmarker
    if img is not None:
        if face_landmarker is None:
            face_landmarker = MediapipeLandmarker()
        return face_landmarker.extract_lm478_from_img(img)
    return None
def extract_landmark_job(img_name):
    """Extract landmarks for one image and store them next to the dataset.

    The output path is derived from ``img_name`` by replacing
    ``/images_512/`` with ``/lms_2d/`` and ``.png`` with ``_lms.npy``.
    If that file already exists the image is skipped.

    Failures are printed rather than raised, so a single bad image does not
    abort a batch run.

    Args:
        img_name: Path of the input image.

    Returns:
        Always ``None``; the result is written to disk with ``np.save``.
    """
    try:
        out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png", "_lms.npy")
        if os.path.exists(out_name):
            print("out exists, skip...")
            return
        try:
            os.makedirs(os.path.dirname(out_name), exist_ok=True)
        except OSError:
            # Best effort: if the directory is truly unusable, np.save below
            # fails and the outer handler reports it.
            pass
        img = cv2.imread(img_name)
        if img is None:
            # cv2.imread signals an unreadable file by returning None, so the
            # check has to happen before the array is indexed.
            print(f"failed to read image: {img_name}")
            return
        img = img[:, :, ::-1]  # reverse the channel axis (imread yields BGR order)
        lm468 = extract_lms_mediapipe_job(img)
        if lm468 is not None:
            np.save(out_name, lm468)
    except Exception as e:
        print(e)
def out_exist_job(img_name):
    """Return ``img_name`` when its landmark file is missing, else ``None``.

    The landmark path is ``img_name`` with ``/images_512/`` replaced by
    ``/lms_2d/`` and ``.png`` replaced by ``_lms.npy``.
    """
    lms_path = img_name.replace("/images_512/", "/lms_2d/")
    lms_path = lms_path.replace(".png", "_lms.npy")
    return None if os.path.exists(lms_path) else img_name
# def get_todo_img_names(img_names):
# todo_img_names = []
# for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=64):
# if res is not None:
# todo_img_names.append(res)
# return todo_img_names
def glob_img_names(img_dir, ds_name):
    """Return the sorted PNG paths under ``img_dir`` for dataset ``ds_name``.

    Args:
        img_dir: Root directory that is searched.
        ds_name: One of ``'FFHQ_MV'``, ``'FFHQ'`` or ``'PanoHeadGen'``.

    Returns:
        Sorted list of matching file paths.

    Raises:
        ValueError: If ``ds_name`` is not one of the supported names.
    """
    patterns_by_dataset = {
        "FFHQ_MV": ["ref_imgs/*.png", "mv_imgs/*.png"],
        "FFHQ": ["*.png"],
        # Only the "ref" views are globbed; "multi_view/*/*.png" and
        # "reverse/*/*.png" are currently disabled.
        "PanoHeadGen": ["ref/*/*.png"],
    }
    if ds_name not in patterns_by_dataset:
        raise ValueError(
            f"unsupported ds_name {ds_name!r}; expected one of {sorted(patterns_by_dataset)}"
        )
    found = []
    for pattern in patterns_by_dataset[ds_name]:
        found.extend(glob.glob(os.path.join(img_dir, pattern)))
    return sorted(found)


def split_for_process(img_names, process_id, total_process):
    """Return the slice of ``img_names`` handled by one of several processes.

    The list is cut into ``total_process`` contiguous shards of equal size;
    the last shard additionally receives the remainder so that every image is
    assigned to exactly one process.

    Args:
        img_names: Full list of image paths.
        process_id: Zero-based index of the current process.
        total_process: Number of cooperating processes. Values ``<= 1``
            disable sharding and return ``img_names`` unchanged.

    Raises:
        ValueError: If sharding is enabled and ``process_id`` is outside
            ``[0, total_process)``.
    """
    if total_process <= 1:
        return img_names
    if not 0 <= process_id < total_process:
        raise ValueError(
            f"process_id must be in [0, {total_process - 1}], got {process_id}"
        )
    per_process = len(img_names) // total_process
    start = process_id * per_process
    if process_id == total_process - 1:
        # The last process also takes the len % total_process leftovers.
        return img_names[start:]
    return img_names[start:start + per_process]


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512/')
    parser.add_argument("--ds_name", default='FFHQ')
    parser.add_argument("--num_workers", default=64, type=int)
    parser.add_argument("--process_id", default=0, type=int)
    parser.add_argument("--total_process", default=1, type=int)
    parser.add_argument("--reset", action='store_true')
    parser.add_argument("--img_names_file", default="img_names.pkl", type=str)
    parser.add_argument("--load_img_names", action="store_true")
    args = parser.parse_args()
    print(f"args {args}")

    img_dir = args.img_dir
    img_names_file = os.path.join(img_dir, args.img_names_file)
    if args.load_img_names:
        # Reuse a previously saved listing instead of globbing again.
        img_names = load_file(img_names_file)
        print(f"load image names from {img_names_file}")
    else:
        img_names = glob_img_names(img_dir, args.ds_name)
        # Save the listing so later runs can pass --load_img_names.
        save_file(img_names_file, img_names)
        print(f"save image names in {img_names_file}")
    print(f"total images number: {len(img_names)}")

    img_names = split_for_process(img_names, args.process_id, args.total_process)
    print(f"todo_image {img_names[:10]}")
    print(f"processing images number in this process: {len(img_names)}")

    if args.num_workers == 1:
        # Print a progress line roughly every 0.3% of the workload.
        report_every = max(1, int(len(img_names) * 0.003))
        progress = tqdm.tqdm(img_names, desc=f"Root process {args.process_id}: extracting MP-based landmark2d")
        for index, img_name in enumerate(progress):
            try:
                extract_landmark_job(img_name)
            except Exception as e:
                print(e)
            if index % report_every == 0:
                print(f"processed {index} / {len(img_names)}")
                sys.stdout.flush()
    else:
        for i, _ in multiprocess_run_tqdm(
                extract_landmark_job, img_names,
                num_workers=args.num_workers,
                desc=f"Root {args.process_id}: extracing MP-based landmark2d"):
            print(f"processed {i+1} / {len(img_names)}")
            sys.stdout.flush()
    print(f"Root {args.process_id}: Finished extracting.")