|
from mivolo_model import MiVOLOModel |
|
import torch |
|
import torchvision.transforms as transforms |
|
from ultralytics import YOLO |
|
from PIL import Image |
|
import numpy as np |
|
import os |
|
import requests |
|
|
|
def download_files_to_cache(urls, file_names, cache_dir_name="age_estimation"):
    """Make sure every file exists in ``~/.cache/<cache_dir_name>``, downloading if missing.

    Args:
        urls: Source URLs.
        file_names: Target file names, paired positionally with ``urls``
            (pairing stops at the shorter of the two sequences).
        cache_dir_name: Sub-directory of ``~/.cache`` in which the files are stored.

    Raises:
        requests.RequestException: If a request fails, times out, or returns an
            HTTP error status.
        OSError: If the cache directory or a file cannot be written.
    """

    def download_file(url, save_path):
        """Stream ``url`` to ``save_path`` without leaving a partial file at that path."""
        # Download into a sibling ".part" file and rename it only on success.
        # Writing directly to save_path would leave a truncated file behind
        # after an interrupted download, and the existence check below would
        # then treat it as a valid cached file on the next run.
        partial_path = save_path + ".part"
        try:
            # (connect, read) timeouts so a stalled connection cannot hang forever.
            with requests.get(url, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as file:
                    for chunk in response.iter_content(chunk_size=8192):
                        file.write(chunk)
            os.replace(partial_path, save_path)
        finally:
            # After a successful rename the ".part" file no longer exists.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print(f"File downloaded and saved to {save_path}")

    cache_dir = os.path.join(os.path.expanduser("~"), ".cache", cache_dir_name)
    os.makedirs(cache_dir, exist_ok=True)

    for url, file_name in zip(urls, file_names):
        save_path = os.path.join(cache_dir, file_name)
        if os.path.exists(save_path):
            print(f"File {file_name} already exists at {save_path}")
            continue
        print(f"File {file_name} does not exist. Downloading...")
        download_file(url, save_path)
|
|
|
|
|
# Remote artefacts: the age-estimation weights and the person/face detector.
urls = [
    "https://huggingface.co/hungdang1610/estimate_age/resolve/main/models/best_model_weights_10.pth?download=true",
    "https://huggingface.co/hungdang1610/estimate_age/resolve/main/models/yolov8x_person_face.pt?download=true",
]

# File names under which the artefacts are cached, in the same order as `urls`.
file_names = ["best_model_weights_10.pth", "yolov8x_person_face.pt"]

# Absolute locations of the cached files inside the user's home directory.
model_path, detection_path = (
    os.path.join(os.path.expanduser("~"), relative_path)
    for relative_path in (
        ".cache/age_estimation/best_model_weights_10.pth",
        ".cache/age_estimation/yolov8x_person_face.pt",
    )
)

# Fetch whatever is not cached yet (runs at import time).
download_files_to_cache(urls, file_names)
|
|
|
# Per-channel mean / std; the same values are passed to transforms.Normalize below.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Used at the end of the script to map the network output back to an age:
# age = output * STD_TRAIN + MEAN_TRAIN.
MEAN_TRAIN = 36.64
STD_TRAIN = 21.74

# in_chans=6 matches the input built by tranfer_image: a 3-channel face crop
# concatenated with a 3-channel person crop.
model = MiVOLOModel(
    layers=(4, 4, 8, 2),
    img_size=224,
    in_chans=6,
    num_classes=3,
    patch_size=8,
    stem_hidden_dim=64,
    embed_dims=(192, 384, 384, 384),
    num_heads=(6, 12, 12, 12),
).to("cpu")

# NOTE(review): torch.load unpickles the checkpoint, and this file was fetched
# over the network. Consider `weights_only=True` if the installed torch
# version supports it.
state = torch.load(model_path, map_location="cpu")
model.load_state_dict(state, strict=True)

# Inference only: switch dropout / normalisation layers to evaluation
# behaviour, otherwise predictions are made in training mode.
model.eval()
|
|
|
# Pre-processing applied to each crop: resize to 224x224 (bicubic), convert to
# a tensor, then normalise per channel. The shared module-level constants
# replace a second copy of the same numbers.
transform_infer = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])

# YOLO detector loaded from the cached weights; chunk_then_stack treats its
# class 0 as "person" and class 1 as "face".
detector = YOLO(detection_path)
|
def chunk_then_stack(image):
    """Find one face box and one person box in ``image`` with the YOLO detector.

    For each class the detection with the highest confidence is kept, rather
    than whichever detection happened to be listed last.

    Args:
        image: A PIL image, or array-like pixel data already in the channel
            order the detector expects.

    Returns:
        ``(face_coords, person_coords)``. Each is ``[x_min, y_min, x_max, y_max]``
        as ints, or ``[None, None, None, None]`` when that class was not detected.
    """
    # Ultralytics interprets raw numpy input as BGR but handles the channel
    # order of PIL images itself, so an RGB PIL image is passed through
    # unchanged instead of being turned into an (RGB) array first.
    source = image if isinstance(image, Image.Image) else np.array(image)
    results = detector.predict(source, conf=0.35)

    # class id -> (confidence, [x_min, y_min, x_max, y_max])
    best_by_class = {}
    for result in results:
        boxes = result.boxes
        detections = zip(boxes.xyxy.tolist(), boxes.cls.tolist(), boxes.conf.tolist())
        for corners, cls_id, score in detections:
            cls_id = int(cls_id)
            current = best_by_class.get(cls_id)
            if current is None or score > current[0]:
                best_by_class[cls_id] = (score, [int(value) for value in corners])

    def coords_for(cls_id):
        """Return the kept box for ``cls_id`` or the all-``None`` placeholder."""
        if cls_id in best_by_class:
            return best_by_class[cls_id][1]
        return [None, None, None, None]

    # Class 1 is the face, class 0 the person.
    face_coords = coords_for(1)
    person_coords = coords_for(0)
    return face_coords, person_coords
|
|
|
|
|
|
|
def tranfer_image(image):
    """Build the batched model input from the face and person crops of ``image``.

    Args:
        image: PIL image to analyse (the script passes an RGB image).

    Returns:
        Tensor with a leading batch dimension of 1, holding the transformed
        face crop followed by the transformed person crop along the channel
        axis (6 x 224 x 224 for an RGB input).

    Raises:
        ValueError: If the detector found no face or no person, so there is
            nothing to crop.
    """
    face_coords, person_coords = chunk_then_stack(image)

    # chunk_then_stack reports a missing detection as a list of None values;
    # fail with a clear message instead of a TypeError from int(None).
    if any(value is None for value in face_coords):
        raise ValueError("No face detected in the image")
    if any(value is None for value in person_coords):
        raise ValueError("No person detected in the image")

    face_crop = image.crop(tuple(int(value) for value in face_coords))
    person_crop = image.crop(tuple(int(value) for value in person_coords))

    # transform_infer already resizes to 224x224 with bicubic interpolation,
    # so a separate PIL resize beforehand is redundant.
    face_tensor = transform_infer(face_crop)
    person_tensor = transform_infer(person_crop)

    # Stack along the channel axis, then add the batch dimension.
    stacked = torch.cat((face_tensor, person_tensor), dim=0)
    return stacked.unsqueeze(0)
|
|
|
import time

# Load the sample image; the context manager closes the file handle once the
# pixels have been copied into the converted RGB image.
with Image.open("1.jpg") as source_image:
    image = source_image.convert("RGB")

image_ = tranfer_image(image)
print(image_.shape)

# perf_counter is a monotonic clock intended for measuring short durations.
start_time = time.perf_counter()

# No gradients are needed for inference; skipping autograd saves memory and time.
with torch.no_grad():
    output = model(image_)

# Column 2 of the output is read as the normalised age; undo the normalisation
# with the training-set statistics.
output_mse = output[:, 2]
predicted_age = output_mse.item() * STD_TRAIN + MEAN_TRAIN

print("inference time: ", time.perf_counter() - start_time)
print("predicted_age: ", predicted_age)