Spaces:
Runtime error
Runtime error
File size: 2,984 Bytes
79146ea 91df7cc 79146ea 91df7cc 79146ea 91df7cc 79146ea 91df7cc 79146ea 91df7cc 79146ea 91df7cc 79146ea 91df7cc 79146ea 91df7cc 79146ea 91df7cc 79146ea 91df7cc 79146ea 91df7cc 79146ea 91df7cc 79146ea 91df7cc 79146ea 91df7cc 79146ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
from typing import Any
import pytorch_lightning as pl
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
import torch
from torch import nn
from torchvision import transforms
import yaml
from yaml.loader import SafeLoader
import gradio as gr
import os
class WeedModel(pl.LightningModule):
    """LightningModule wrapping a torchvision EfficientNetV2-S classifier.

    Args:
        params: Config mapping. Keys read here: ``"model"`` (backbone name,
            compared case-insensitively; only ``"efficientnet"`` is
            implemented), ``"pretrained"`` (truthy -> start from the
            ``IMAGENET1K_V1`` weights) and ``"n_class"`` (output size of the
            replacement classification layer).

    Raises:
        ValueError: If ``params["model"]`` is not a supported backbone.
    """

    def __init__(self, params):
        super().__init__()
        self.params = params
        model_name = self.params["model"]
        if model_name.lower() != "efficientnet":
            # Fail fast instead of leaving the module without ``base_model``,
            # which would otherwise only surface later as an AttributeError in
            # ``forward`` (or as a state-dict mismatch when loading a checkpoint).
            raise ValueError(
                f"Unsupported model {model_name!r}: only 'efficientnet' is implemented"
            )
        weights = (
            EfficientNet_V2_S_Weights.IMAGENET1K_V1 if self.params["pretrained"] else None
        )
        self.base_model = efficientnet_v2_s(weights=weights)
        # Swap the backbone's last classifier layer for one sized to our classes.
        num_ftrs = self.base_model.classifier[-1].in_features
        self.base_model.classifier[-1] = nn.Linear(num_ftrs, self.params["n_class"])

    def forward(self, x):
        """Return the raw (pre-softmax) class logits produced by the backbone."""
        logits = self.base_model(x)
        return logits

    def predict_step(
        self, batch: Any, batch_idx: int = 0, dataloader_idx: int = 0
    ) -> Any:
        """Return softmax class probabilities as nested Python lists.

        ``batch`` is passed straight to ``forward``; presumably an image tensor
        of shape (N, 3, H, W) -- TODO confirm against callers. The softmax is
        taken over the last axis, giving one list of per-class probabilities
        per sample. ``batch_idx`` and ``dataloader_idx`` are unused; they only
        match Lightning's ``predict_step`` signature.
        """
        # Inference never needs gradients. Lightning's predict loop already
        # disables them, but direct calls (as the Gradio ``predict`` does)
        # would otherwise build an autograd graph for nothing.
        with torch.no_grad():
            y_hat = self(batch)
        return torch.softmax(y_hat, dim=-1).tolist()
def predict(image):
    """Classify one image and return class probabilities for ``gr.Label``.

    Uses the module-level ``transform``, ``model`` and ``class_names``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the predict button is clicked with no image loaded.

    Returns:
        Dict mapping class name -> probability, or ``None`` (clears the label)
        when no image was provided.
    """
    if image is None:
        return None
    # Uploaded PNGs can be RGBA/palette and some files are grayscale; the
    # backbone expects 3-channel input, so normalise the mode first.
    tensor_image = transform(image.convert("RGB"))
    outs = model.predict_step(tensor_image.unsqueeze(0))
    # NOTE(review): ``[:-1]`` hides the model's last output class from the
    # result. This may be deliberate (e.g. a catch-all class with no entry in
    # class_names.txt) -- confirm against the label order used in training.
    return {class_names[k]: float(v) for k, v in enumerate(outs[0][:-1])}
title = " AISeed AI Application Demo "
description = "# A Demo of Deep Learning for Weed Classification"
# Each example is a one-element list (one value per gr.Examples input).
# Sorted so the gallery order is stable (os.listdir order is arbitrary), and
# hidden files such as .gitkeep/.DS_Store are skipped since they are not images.
example_list = [
    [os.path.join("examples", example)]
    for example in sorted(os.listdir("examples"))
    if not example.startswith(".")
]
# One class name per line; predict() labels model output index k with
# class_names[k], so line order must match the training label order.
with open("class_names.txt", "r", encoding="utf-8") as f:
    class_names = f.read().splitlines()
# UI: a single "Images" tab with the input image on the left, the top-4
# predicted classes on the right, a predict button, and clickable examples
# that fill the input image.
with gr.Blocks() as demo:
    demo.title = title
    gr.Markdown(description)
    with gr.Tabs(), gr.TabItem("Images"):
        with gr.Row():
            with gr.Column():
                im = gr.Image(
                    type="pil",
                    label="input image",
                    sources=["upload", "webcam"],
                )
            with gr.Column():
                label_conv = gr.Label(
                    label="Predictions",
                    num_top_classes=4,
                )
        btn = gr.Button(value="predict")
        btn.click(fn=predict, inputs=im, outputs=[label_conv])
        gr.Examples(
            examples=example_list,
            inputs=[im],
            outputs=[label_conv],
        )
# Config, model and preprocessing are set up at module level rather than only
# under the __main__ guard: predict() reads the module-level ``model`` and
# ``transform``, so they must exist even when this module is imported instead
# of run as a script. Only launching the server stays script-only.
with open("config.yaml", encoding="utf-8") as f:
    PARAMS = yaml.load(f, Loader=SafeLoader)
print(PARAMS)
# map_location forces CPU so the checkpoint loads on machines without a GPU.
model = WeedModel.load_from_checkpoint(
    "model/epoch=08.ckpt", params=PARAMS, map_location=torch.device("cpu")
)
model.eval()
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # NOTE(review): ImageNet normalisation is disabled; presumably the
        # checkpoint was trained on un-normalised [0, 1] tensors -- confirm
        # before re-enabling.
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

if __name__ == "__main__":
    demo.launch()
|