---
tags:
- model_hub_mixin
- pytorch_model_hub_mixin
---

This is a fashion image feature extractor model.

# 1. Model Architecture

I used [microsoft/swin-base-patch4-window7-224](https://huggingface.co/microsoft/swin-base-patch4-window7-224) as the base image encoder and added a fully connected layer on top of it that reduces the embedding size to 128.

The dataset consists of anchor-positive image pairs, where the anchor is a product region detected in a user post and the positive is the corresponding product thumbnail. For each anchor, every sample in the batch other than its own positive was used as a negative, and the model was trained to minimize the distance between anchor-positive pairs while maximizing the distance between anchor-negative pairs. This is contrastive learning with in-batch negatives, the same training method used by OpenAI's CLIP model.
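
A minimal NumPy sketch of an in-batch contrastive loss of the kind described above. It assumes L2-normalized embeddings, a CLIP-style symmetric cross-entropy over the anchor-positive similarity matrix, and a temperature of 0.07; the exact loss formulation and temperature used to train this model are not stated here, and the function name `in_batch_contrastive_loss` is hypothetical.

```python
import numpy as np


def log_softmax(logits: np.ndarray, axis: int) -> np.ndarray:
    # Numerically stable log-softmax along the given axis.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def in_batch_contrastive_loss(anchors: np.ndarray, positives: np.ndarray,
                              temperature: float = 0.07) -> float:
    """Hypothetical CLIP-style loss: row i of `positives` is the positive for
    row i of `anchors`; every other row in the batch acts as a negative.
    Assumes both inputs are already L2-normalized, shape (batch, dim)."""
    logits = anchors @ positives.T / temperature  # (batch, batch) scaled cosine similarities
    idx = np.arange(len(anchors))
    # Anchor -> positive direction: the correct column for row i is column i.
    loss_a = -log_softmax(logits, axis=1)[idx, idx].mean()
    # Positive -> anchor direction (the symmetric term, as in CLIP).
    loss_p = -log_softmax(logits, axis=0)[idx, idx].mean()
    return float((loss_a + loss_p) / 2)


rng = np.random.default_rng(0)
a = rng.normal(size=(8, 128))
a /= np.linalg.norm(a, axis=1, keepdims=True)
noisy = a + rng.normal(scale=0.05, size=a.shape)
p_matched = noisy / np.linalg.norm(noisy, axis=1, keepdims=True)
p_shuffled = p_matched[::-1]  # positives no longer aligned with their anchors

print(in_batch_contrastive_loss(a, p_matched))   # small: positives sit on the diagonal
print(in_batch_contrastive_loss(a, p_shuffled))  # large: positives are misaligned
```

Each row of the similarity matrix is treated as a classification problem over the whole batch, which is why every other sample acts as a negative without being sampled explicitly.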

Initially, anchor-positive-negative triplets were constructed explicitly in a 1:1:1 ratio and trained with triplet loss, but training with in-batch negative sampling and contrastive loss performed much better, since each anchor could learn from many more negatives.
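
For comparison, a minimal sketch of the triplet setup, where each anchor is scored against exactly one explicitly sampled negative, next to a count of how many negatives each anchor gets from in-batch sampling. The margin of 0.2, the Euclidean distance, and the helper names are illustrative assumptions, not the settings used for this model.

```python
import numpy as np


def triplet_loss(anchors, positives, negatives, margin=0.2):
    """Triplet margin loss: mean of max(0, d(a, p) - d(a, n) + margin).
    The margin value is an illustrative assumption."""
    d_ap = np.linalg.norm(anchors - positives, axis=1)
    d_an = np.linalg.norm(anchors - negatives, axis=1)
    return float(np.maximum(0.0, d_ap - d_an + margin).mean())


def in_batch_negative_count(batch_size):
    # In the (B, B) anchor-positive similarity matrix, every off-diagonal
    # entry in a row is a negative for that row's anchor.
    off_diagonal = ~np.eye(batch_size, dtype=bool)
    return off_diagonal.sum(axis=1)


rng = np.random.default_rng(0)
a = rng.normal(size=(4, 128))
a /= np.linalg.norm(a, axis=1, keepdims=True)
p = a.copy()               # perfect positives
n = np.roll(a, 1, axis=0)  # another anchor used as each row's single negative

print(triplet_loss(a, p, n))       # 0.0: every negative is already beyond the margin
print(in_batch_negative_count(4))  # [3 3 3 3]: B - 1 negatives per anchor
```

With a 1:1:1 triplet each anchor sees one negative per step, while in-batch sampling gives it B - 1 negatives for the same forward pass.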

<img src="image_encoder.png" width="500" alt="image_encoder">

<img src="contrastive_learning.png" width="500" alt="contrastive_learning">

# 2. Training dataset

User posting images from onthelook and kream were crawled and preprocessed. First, raw data pairing each posting image with its product thumbnails was collected. Then, object detection was run on the posting images, and a trained category classifier was applied to the product thumbnails, so that detected product regions and thumbnails of the same category could be paired together. Finally, about 290,000 anchor-positive image pairs were created for 6 categories: tops, bottoms, outer, shoes, bags, and hats.

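
A minimal sketch of the pairing step, under assumptions the description above does not state: detections and classified thumbnails are plain records, only the six listed categories are kept, and a pair is created only when a post has exactly one detected region and exactly one thumbnail of a given category, so the match is unambiguous. The `Detection`/`Thumbnail` types, the `make_pairs` function, and the file names are hypothetical.

```python
from collections import defaultdict
from dataclasses import dataclass

CATEGORIES = {"tops", "bottoms", "outer", "shoes", "bags", "hats"}


@dataclass
class Detection:
    post_id: str
    category: str    # label from the object detection model
    crop_path: str   # cropped product region (anchor)


@dataclass
class Thumbnail:
    post_id: str
    category: str    # label from the thumbnail category classifier
    image_path: str  # product thumbnail (positive)


def make_pairs(detections, thumbnails):
    """Pair anchors and positives that share a post and a category.
    Hypothetical rule: keep only unambiguous one-to-one matches."""
    dets = defaultdict(list)
    thumbs = defaultdict(list)
    for d in detections:
        if d.category in CATEGORIES:
            dets[(d.post_id, d.category)].append(d)
    for t in thumbnails:
        if t.category in CATEGORIES:
            thumbs[(t.post_id, t.category)].append(t)
    pairs = []
    for key, ds in dets.items():
        ts = thumbs.get(key, [])
        if len(ds) == 1 and len(ts) == 1:
            pairs.append((ds[0].crop_path, ts[0].image_path))
    return pairs


detections = [
    Detection("p1", "tops", "p1_top.jpg"),
    Detection("p1", "shoes", "p1_shoes.jpg"),  # no shoes thumbnail: skipped
    Detection("p2", "bags", "p2_bag_a.jpg"),
    Detection("p2", "bags", "p2_bag_b.jpg"),   # two bags in one post: ambiguous, skipped
]
thumbnails = [
    Thumbnail("p1", "tops", "p1_top_thumb.jpg"),
    Thumbnail("p1", "hats", "p1_hat_thumb.jpg"),  # no detected hat: skipped
    Thumbnail("p2", "bags", "p2_bag_thumb.jpg"),
]
print(make_pairs(detections, thumbnails))
# [('p1_top.jpg', 'p1_top_thumb.jpg')]
```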

You can find the object detection model here: [yainage90/fashion-object-detection](https://huggingface.co/yainage90/fashion-object-detection)

You can find more details about the model in this GitHub repo: [fashion-visual-search](https://github.com/yainage90/fashion-visual-search)

```python
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from transformers import AutoImageProcessor, SwinModel, SwinConfig
from huggingface_hub import PyTorchModelHubMixin

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ckpt = "yainage90/fashion-image-feature-extractor"
# The Swin config and the image processor (resize size, mean/std) live in the same repo.
encoder_config = SwinConfig.from_pretrained(ckpt)
encoder_image_processor = AutoImageProcessor.from_pretrained(ckpt)


class ImageEncoder(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        # Swin backbone built from the config; weights are loaded by from_pretrained.
        self.swin = SwinModel(config=encoder_config)
        # Projection head: hidden_size -> 128-dimensional embedding.
        self.embedding_layer = nn.Linear(encoder_config.hidden_size, 128)

    def forward(self, image_tensor):
        # Pooled Swin features, shape (batch, hidden_size).
        features = self.swin(image_tensor).pooler_output
        embeddings = self.embedding_layer(features)
        # L2-normalize so that the dot product equals cosine similarity.
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings


# from_pretrained is a classmethod, so call it on the class. eval() disables
# dropout / stochastic depth so inference is deterministic.
encoder = ImageEncoder.from_pretrained(ckpt).to(device).eval()

transform = T.Compose([
    T.Resize((encoder_config.image_size, encoder_config.image_size)),
    T.ToTensor(),
    T.Normalize(mean=encoder_image_processor.image_mean, std=encoder_image_processor.image_std),
])

image = Image.open('<path/to/image>').convert('RGB')
image = transform(image)

with torch.no_grad():
    # Add a batch dimension; the result is a (1, 128) NumPy array.
    embedding = encoder(image.unsqueeze(0).to(device)).cpu().numpy()
```
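
Because the encoder output is L2-normalized, visual search reduces to ranking gallery embeddings by dot product with the query. A minimal NumPy sketch, assuming a gallery matrix built by stacking encoder outputs for product thumbnails; here random unit vectors stand in for real embeddings, and `top_k_similar` is a hypothetical helper, not part of this repo.

```python
import numpy as np


def top_k_similar(query: np.ndarray, gallery: np.ndarray, k: int = 5):
    """Return (indices, scores) of the k gallery rows most similar to `query`.
    Assumes every row is L2-normalized, as the encoder output is, so the dot
    product equals cosine similarity."""
    scores = gallery @ query.reshape(-1)          # (N,)
    k = min(k, len(scores))
    idx = np.argpartition(-scores, k - 1)[:k]     # unordered top-k
    idx = idx[np.argsort(-scores[idx])]           # sort those k by descending score
    return idx, scores[idx]


rng = np.random.default_rng(0)
gallery = rng.normal(size=(1000, 128))
gallery /= np.linalg.norm(gallery, axis=1, keepdims=True)

# Stand-in for `embedding[0]` from the snippet above: a slightly perturbed copy of item 42.
query = gallery[42] + rng.normal(scale=0.01, size=128)
query /= np.linalg.norm(query)

idx, scores = top_k_similar(query, gallery, k=3)
print(idx[0])  # 42
```

For a large catalog, the same dot-product ranking is usually delegated to an approximate nearest-neighbor index rather than a full matrix product.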

<img src="detection_image1.png" width="500" alt="detection_image1">

<img src="result_image1.png" width="700" alt="result_image1">

<img src="detection_image2.png" width="500" alt="detection_image2">

<img src="result_image2.png" width="700" alt="result_image2">