|
|
|
|
|
|
|
import os |
|
from typing import List |
|
|
|
import torch |
|
from PIL import Image |
|
from fastapi import HTTPException |
|
from mtc_api_utils.base_model import MLBaseModel |
|
from mtc_api_utils.init_api import download_if_not_exists |
|
from torchvision import transforms |
|
|
|
from aesthetics_model.api_types import AestheticsModelResponse |
|
from aesthetics_model.config import AestheticsModelConfig |
|
from aesthetics_model.model.vision_transformer import vit_large_patch16_224_in21k |
|
|
|
|
|
class AestheticModel(MLBaseModel):
    """Aesthetics-scoring model: a ViT-L/16 backbone with a single-unit regression head.

    The model is lazily initialized via init_model(); until then all class
    attributes below remain None and inference() responds with HTTP 503.
    """

    model = None                  # torch.nn.Module (ViT) once loaded
    transform = None              # torchvision preprocessing pipeline
    device = None                 # torch.device the model runs on
    model_checkpoint_path = None  # local filesystem path of the downloaded checkpoint

    def init_model(self):
        """Download the checkpoint if necessary, then load model/transform/device."""
        print("Initializing Aesthetic Model...")

        self.model_checkpoint_path = download_if_not_exists(
            artifact_url=os.path.join(AestheticsModelConfig.model_base_url, AestheticsModelConfig.model_url),
            download_dir=AestheticsModelConfig.model_checkpoint_path,
            polybox_auth=AestheticsModelConfig.polybox_credentials,
            is_tar=False,
        )
        print("Model downloaded, loading to device..")

        self.model, self.transform, self.device = self.load_model()
        print("Model initialized")

    def is_ready(self):
        """Return True once the model has been loaded onto the device."""
        return self.model is not None

    def inference(self, images: List[Image.Image]) -> List[AestheticsModelResponse]:
        """Score a batch of images for aesthetic quality.

        Args:
            images: PIL images; non-RGB images are converted, oversized images
                are downscaled (see validate_and_resize_image).

        Returns:
            One AestheticsModelResponse per input image, carrying the scalar
            aesthetics score and the backbone embedding.

        Raises:
            HTTPException: 503 if the model has not finished initializing.
        """
        if not self.is_ready():
            raise HTTPException(status_code=503, detail="Model is not ready")

        results = []
        for image in images:
            im = self.validate_and_resize_image(image, max_size=AestheticsModelConfig.max_image_size)
            image_tensor = self.transform(im).unsqueeze(0).to(self.device)
            # Both feature extraction and head scoring run under no_grad so no
            # autograd graph is built during inference.
            with torch.no_grad():
                image_embedding = self.model.forward_features(image_tensor)
                score = self.model.head(image_embedding).squeeze().to("cpu")
            result: AestheticsModelResponse = AestheticsModelResponse(
                aesthetics_score=float(score),
                aesthetics_embedding=image_embedding.to("cpu").tolist(),
                success=True,
            )
            results.append(result)
        return results

    def load_model(self):
        """Build the ViT, load the downloaded checkpoint, and return (model, transform, device)."""
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda:0" if use_cuda else "cpu")
        print(f"Running on {device}")

        model = vit_large_patch16_224_in21k()
        model.reset_classifier(num_classes=1)
        # NOTE(review): torch.load unpickles arbitrary objects; the checkpoint
        # comes from the project's own artifact store, so this is trusted input.
        model.load_state_dict(torch.load(self.model_checkpoint_path, map_location=device))

        model.eval()
        model.to(device)

        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
        )

        return model, transform, device

    def validate_and_resize_image(self, image: Image.Image, max_size: int):
        """Convert to RGB if needed and downscale so the longest side is at most max_size.

        Args:
            image: PIL image.
            max_size: maximum allowed pixel size of the longest side.

        Returns:
            The (possibly converted and resized) image, aspect ratio preserved.
        """
        if image.mode != "RGB":
            image = image.convert(mode="RGB")

        size_factor = max_size / max(image.size)
        if size_factor < 1:
            # PIL's image.size is (width, height) and resize() also expects
            # (width, height). The original code derived the new width from
            # size[1] and the new height from size[0], flipping the aspect
            # ratio of every non-square oversized image — fixed here.
            new_w = round(image.size[0] * size_factor)
            new_h = round(image.size[1] * size_factor)
            image = image.resize((new_w, new_h), resample=Image.HAMMING)
        return image
|
|
|
|
|
if __name__ == '__main__':
    model = AestheticModel()
    model.__wait_until_ready__()

    example_images = ["../tests/test_image.png"]

    # inference() expects raw PIL images; the original code passed each image
    # through the loaded torch module (model.model(...)) first, which skips the
    # preprocessing transform and feeds tensors where PIL images are required.
    images = [Image.open(path) for path in example_images]

    model_result = model.inference(images=images)
    print(model_result)
|
|